diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..71f6a2d --- /dev/null +++ b/.gitignore @@ -0,0 +1,25 @@ +# If you prefer the allow list template instead of the deny list, see community template: +# https://github.com/github/gitignore/blob/main/community/Golang/Go.AllowList.gitignore +# +# Binaries for programs and plugins +*.exe +*.exe~ +*.dll +*.so +*.dylib + +# Test binary, built with `go test -c` +*.test + +# Dependency directories +vendor/ + +# Go workspace file +go.work +go.work.sum + +# env file +.env + +build/ +.task/ diff --git a/.golangci.yaml b/.golangci.yaml new file mode 100644 index 0000000..8b080ac --- /dev/null +++ b/.golangci.yaml @@ -0,0 +1,48 @@ +version: "2" + +linters: + enable: + - errcheck + - govet + - ineffassign + - staticcheck + - unused + - copyloopvar + - dupl + - err113 + - errname + - exptostd + - fatcontext + - funlen + - gocognit + - goconst + - gocritic + - gocyclo + - godot + - godox + - gosec + - perfsprint + - testifylint + exclusions: + rules: + - path: '(.+)_test\.go' + linters: + - errcheck + - err113 + - gosec + - gocognit + - gocyclo + settings: + testifylint: + enable-all: true + disable: + - require-error + gocognit: + min-complexity: 10 + gocyclo: + min-complexity: 10 + gocritic: + enable-all: true + settings: + hugeParam: + sizeThreshold: 255 diff --git a/Taskfile.yml b/Taskfile.yml new file mode 100644 index 0000000..8f64c1f --- /dev/null +++ b/Taskfile.yml @@ -0,0 +1,65 @@ +# https://taskfile.dev + +version: '3' + +vars: + GOBIN_ENV: + sh: go env GOBIN + GOPATH_ENV: + sh: go env GOPATH + BIN: '{{if .GOBIN_ENV}}{{.GOBIN_ENV}}{{else}}{{.GOPATH_ENV}}/bin{{end}}' + +tasks: + default: + desc: fmt, lint, test + deps: + - fmt + - lint + - test + + fmt: + desc: Format go files + sources: + - "*.go" + cmds: + - go fmt ./... + + lint: + desc: Statically analyze code + sources: + - '*.go' + cmds: + - golangci-lint run + + test: + desc: Run all tests + sources: + - '*.go' + generates: + - build/cover.out + cmds: + - go test -race -cover -coverprofile build/cover.out . + + coverage-report: + desc: Build coverage report + deps: [test] + sources: + - build/cover.out + generates: + - build/cover.html + cmds: + - go tool cover -html=build/cover.out -o build/cover.html + + serve-report: + desc: Serve the coverage report + deps: [coverage-report] + sources: + - build/cover.html + cmds: + - ip addr list | grep inet + - python3 -m http.server -d build/ 3434 + + serve-docs: + desc: Serve the docs + cmds: + - godoc -http=0.0.0.0:3434 -play diff --git a/errgroup.go b/errgroup.go index 1d8cffa..313b360 100644 --- a/errgroup.go +++ b/errgroup.go @@ -13,6 +13,7 @@ import ( "context" "fmt" "sync" + "sync/atomic" ) type token struct{} @@ -27,15 +28,15 @@ type Group struct { wg sync.WaitGroup - sem chan token + sem atomic.Value errOnce sync.Once - err error + err atomic.Value } func (g *Group) done() { - if g.sem != nil { - <-g.sem + if sem := g.sema(); sem != nil { + <-sem } g.wg.Done() } @@ -50,14 +51,48 @@ func WithContext(ctx context.Context) (*Group, context.Context) { return &Group{cancel: cancel}, ctx } +type semContainer struct{ sem chan token } + +func (g *Group) sema() chan token { + v := g.sem.Load() + if v == nil { + return nil + } + return v.(semContainer).sem +} + +func (g *Group) setSema(ch chan token) { + g.sem.Store(semContainer{ch}) +} + +func (g *Group) error() error { + v := g.err.Load() + if v == nil { + return nil + } + return v.(error) +} + +func (g *Group) setError(err error) { + if err == nil { + return + } + g.errOnce.Do(func() { + g.err.Store(err) + if g.cancel != nil { + g.cancel(err) + } + }) +} + // Wait blocks until all function calls from the Go method have returned, then // returns the first non-nil error (if any) from them. func (g *Group) Wait() error { g.wg.Wait() if g.cancel != nil { - g.cancel(g.err) + g.cancel(g.error()) } - return g.err + return g.error() } // Go calls the given function in a new goroutine. @@ -70,8 +105,11 @@ func (g *Group) Wait() error { // cancel the associated Context, if any. The error will be returned // by Wait. func (g *Group) Go(f func() error) { - if g.sem != nil { - g.sem <- token{} + if g.error() != nil { + return + } + if sem := g.sema(); sem != nil { + sem <- token{} } g.wg.Add(1) @@ -90,14 +128,7 @@ func (g *Group) Go(f func() error) { // that prevents the Wait call from being reached. // See #53757, #74275, #74304, #74306. - if err := f(); err != nil { - g.errOnce.Do(func() { - g.err = err - if g.cancel != nil { - g.cancel(g.err) - } - }) - } + g.setError(f()) }() } @@ -106,9 +137,12 @@ func (g *Group) Go(f func() error) { // // The return value reports whether the goroutine was started. func (g *Group) TryGo(f func() error) bool { - if g.sem != nil { + if g.error() != nil { + return false + } + if sem := g.sema(); sem != nil { select { - case g.sem <- token{}: + case sem <- token{}: // Note: this allows barging iff channels in general allow barging. default: return false @@ -118,19 +152,17 @@ func (g *Group) TryGo(f func() error) bool { g.wg.Add(1) go func() { defer g.done() - - if err := f(); err != nil { - g.errOnce.Do(func() { - g.err = err - if g.cancel != nil { - g.cancel(g.err) - } - }) - } + g.setError(f()) }() return true } +type ErrgroupLimitError struct{ Size int } + +func (egerr *ErrgroupLimitError) Error() string { + return fmt.Sprintf("errgroup: modify limit while %v goroutines in the group are still active", egerr.Size) +} + // SetLimit limits the number of active goroutines in this group to at most n. // A negative value indicates no limit. // A limit of zero will prevent any new goroutines from being added. @@ -141,11 +173,13 @@ func (g *Group) TryGo(f func() error) bool { // The limit must not be modified while any goroutines in the group are active. func (g *Group) SetLimit(n int) { if n < 0 { - g.sem = nil + g.setSema(nil) return } - if len(g.sem) != 0 { - panic(fmt.Errorf("errgroup: modify limit while %v goroutines in the group are still active", len(g.sem))) + + if sem := g.sema(); sem != nil && len(sem) != 0 { + var err error = &ErrgroupLimitError{len(sem)} + panic(err) } - g.sem = make(chan token, n) + g.setSema(make(chan token, n)) } diff --git a/errgroup_test.go b/errgroup_test.go index 05e81e6..7e07ef7 100644 --- a/errgroup_test.go +++ b/errgroup_test.go @@ -43,8 +43,6 @@ func ExampleGroup_justErrors() { "http://www.somestupidname.com/", } for _, url := range urls { - // Launch a goroutine to fetch the URL. - url := url // https://golang.org/doc/faq#closures_and_goroutines g.Go(func() error { // Fetch the URL. resp, err := http.Get(url) @@ -71,7 +69,6 @@ func ExampleGroup_parallel() { searches := []Search{Web, Image, Video} results := make([]Result, len(searches)) for i, search := range searches { - i, search := i, search // https://golang.org/doc/faq#closures_and_goroutines g.Go(func() error { result, err := search(ctx, query) if err == nil { @@ -101,6 +98,38 @@ func ExampleGroup_parallel() { // video result for "golang" } +// FirstError demonstrates that g.Go becomes a no-op if a previous g.Go +// has returned an error. +func ExampleGroup_firstError() { + err1 := errors.New("errgroup_test: 1") + err2 := errors.New("errgroup_test: 2") + + g := new(errgroup.Group) + + ch := make(chan struct{}) + + g.Go(func() error { + fmt.Printf("Returning %s\n", err1) + ch <- struct{}{} + return err1 + }) + + <-ch + + g.Go(func() error { + // This should never run + fmt.Printf("Returning %s\n", err2) + return err2 + }) + + err := g.Wait() + fmt.Printf("Got %s\n", err) + + // Output: + // Returning errgroup_test: 1 + // Got errgroup_test: 1 +} + func TestZeroGroup(t *testing.T) { err1 := errors.New("errgroup_test: 1") err2 := errors.New("errgroup_test: 2") @@ -120,7 +149,6 @@ func TestZeroGroup(t *testing.T) { var firstErr error for i, err := range tc.errs { - err := err g.Go(func() error { return err }) if firstErr == nil && err != nil { @@ -153,7 +181,6 @@ func TestWithContext(t *testing.T) { g, ctx := errgroup.WithContext(context.Background()) for _, err := range tc.errs { - err := err g.Go(func() error { return err }) } @@ -228,6 +255,45 @@ func TestTryGo(t *testing.T) { g.Wait() } +func TestTryGoError(t *testing.T) { + g := &errgroup.Group{} + ch := make(chan struct{}) + var count int32 = 0 + err1 := errors.New("group_test: err1") + err2 := errors.New("group_test: err2") + + ran := g.TryGo(func() error { + atomic.AddInt32(&count, 1) + ch <- struct{}{} + return err1 + }) + + if !ran { + t.Error("TryGo should succeed but failed first run") + } + + <-ch + + ran = g.TryGo(func() error { + atomic.AddInt32(&count, 1) + return err2 + }) + + if ran { + t.Error("TryGo should failed but succeeded second run") + } + + if count != 1 { + t.Errorf("TryGo should have run 1 time, but ran %d times", count) + } + + err := g.Wait() + + if err != err1 { + t.Errorf("g.Wait() = %v, wanted %v", err, err1) + } +} + func TestGoLimit(t *testing.T) { const limit = 10 @@ -250,6 +316,111 @@ func TestGoLimit(t *testing.T) { } } +func TestSetLimit(t *testing.T) { + g := &errgroup.Group{} + ch := make(chan struct{}) + var count int32 = 0 + + g.SetLimit(0) + ran := g.TryGo(func() error { + atomic.AddInt32(&count, 1) + return nil + }) + if ran { + t.Fatal("TryGo should fail but succeeded first run") + } + + g.SetLimit(-1) + ran = g.TryGo(func() error { + atomic.AddInt32(&count, 1) + ch <- struct{}{} + return nil + }) + if !ran { + t.Fatal("TryGo should succeed but failed second run") + } + + <-ch + + g.SetLimit(1) + ran = g.TryGo(func() error { + atomic.AddInt32(&count, 1) + ch <- struct{}{} + return nil + }) + if !ran { + t.Fatal("TryGo should succeed but failed third run") + } + ran = g.TryGo(func() error { + atomic.AddInt32(&count, 1) + return nil + }) + if ran { + t.Fatal("TryGo should fail but succeeded fourth run") + } + + <-ch + + if count != 2 { + t.Errorf("TryGo should have run 2 times, but ran %d times", count) + } + + err := g.Wait() + + if err != nil { + t.Errorf("g.Wait() = %v, wanted %v", err, nil) + } + +} + +func TestLimitPanic(t *testing.T) { + g := &errgroup.Group{} + ch := make(chan struct{}) + + g.SetLimit(2) + + ran := g.TryGo(func() error { + ch <- struct{}{} + return nil + }) + if !ran { + t.Fatal("TryGo should succeed but failed first run") + } + + var err error + func() { + defer func() { + rec := recover() + if rec == nil { + t.Fatal("SetLimit should have panicked, but didn't") + } + + er, ok := rec.(error) + if !ok { + t.Fatalf("SetLimit should have panicked with an error, %T received", rec) + } + + err = er + }() + g.SetLimit(5) + }() + + <-ch + + glimErr := &errgroup.ErrgroupLimitError{} + if !errors.As(err, &glimErr) { + t.Fatalf("panicked error should have been ErrgroupLimitError. %T received", err) + } + + if glimErr.Size != 1 { + t.Fatalf("ErrgroupLimitError should have had a size of 1. Got %d", glimErr.Size) + } + expErrMsg := "errgroup: modify limit while 1 goroutines in the group are still active" + if errStr := glimErr.Error(); errStr != expErrMsg { + t.Fatalf("error message should have been %s. Got %s", expErrMsg, errStr) + } +} + func TestCancelCause(t *testing.T) { errDoom := errors.New("group_test: doomed") @@ -261,13 +432,13 @@ func TestCancelCause(t *testing.T) { {errs: []error{nil}, want: nil}, {errs: []error{errDoom}, want: errDoom}, {errs: []error{errDoom, nil}, want: errDoom}, + {errs: []error{nil, errDoom}, want: errDoom}, } for _, tc := range cases { g, ctx := errgroup.WithContext(context.Background()) for _, err := range tc.errs { - err := err g.TryGo(func() error { return err }) }