diff --git a/CHANGELOG.md b/CHANGELOG.md index 9d00936..7f6acd9 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,18 @@ # Changelog +## v1.0.0 - 2025-11-14 + +### Added +- Refactored WaitErr into an interface and added a New constructor. +- Added WithContext function for context cancellation. +- Added CONTRIBUTING.md for human contributors. + +### Changed +- Updated README.md and example_test.go to reflect the new interface and WithContext function. +- Updated README.md to reference CONTRIBUTING.md. + + + ## v0.9.0 - 2025-11-13 ### Added diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md new file mode 100644 index 0000000..4b8978c --- /dev/null +++ b/CONTRIBUTING.md @@ -0,0 +1,34 @@ +# Contributing to waiterr + +We welcome contributions to the `waiterr` project! Please take a moment to review these guidelines before submitting your contributions. + +## Reporting Bugs and Suggesting Features + +If you encounter a bug or have a feature request, please report it on our [Codeberg repository](https://codeberg.org/danjones000/waiterr/issues). + +## Git Flow Guidelines + +We follow a Git Flow branching model. + +* **`develop` branch**: This is our main integration branch for new features and bug fixes. +* **`stable` branch**: This branch contains the latest production-ready code. + +### Making Changes + +1. **Branching**: + * For new features or regular bug fixes, create a new branch from `develop` (e.g., `feat/your-feature-name` or `bug/your-bug-fix`). + * For urgent hotfixes addressing critical issues in `stable`, create a branch directly from `stable` (e.g., `hot/your-hotfix-name`). + +2. **Pull Requests (PRs)**: + * All new features and regular bug fixes should be submitted as Pull Requests targeting the `develop` branch. + * Hotfixes should be submitted as Pull Requests targeting the `stable` branch directly. After a hotfix is merged into `stable`, it must also be merged back into `develop`. + +3. **Commit Messages**: + * It's not *required* that you follow the [Gitmoji convention](https://gitmoji.dev/) for your commit messages, but it would make me happy if you did. 😏 + * Write clear, concise, and descriptive commit messages that explain *what* changed and *why*. + +## Code Style + +Please ensure your code adheres to the existing Go code style and formatting conventions used in the project. Run `go fmt ./...` and `go mod tidy` before submitting your changes. + +Thank you for contributing! diff --git a/README.md b/README.md index b503e38..d05ac57 100644 --- a/README.md +++ b/README.md @@ -33,7 +33,7 @@ import ( ) func main() { - we := new(waiterr.WaitErr) + we := waiterr.New() we.Go(func() error { time.Sleep(100 * time.Millisecond) @@ -60,7 +60,7 @@ func main() { } // You can also get the first error immediately - we2 := new(waiterr.WaitErr) + we2 := waiterr.New() we2.Go(func() error { time.Sleep(100 * time.Millisecond) return errors.New("first error from we2") @@ -87,9 +87,45 @@ func main() { } ``` +### Using WithContext + +```go +package main + +import ( + "context" + "fmt" + "time" + + "codeberg.org/danjones000/waiterr" +) + +func main() { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + we, ctx := waiterr.WithContext(ctx) + + we.Go(func() error { + select { + case <-time.After(100 * time.Millisecond): + fmt.Println("Task completed") + return nil + case <-ctx.Done(): + fmt.Println("Task cancelled") + return ctx.Err() + } + }) + + _ = we.Wait() + // Output: + // Task completed +} +``` + ## Contributing -Please refer to the [AGENTS.md](AGENTS.md) file for guidelines on contributing to this project, including code style, commit messages, and Git workflow. +Please refer to the [CONTRIBUTING.md](CONTRIBUTING.md) file for guidelines on contributing to this project, including code style, commit messages, and Git workflow. ## License diff --git a/example_test.go b/example_test.go index 3246ddd..00320d5 100644 --- a/example_test.go +++ b/example_test.go @@ -1,6 +1,7 @@ package waiterr_test import ( + "context" "errors" "fmt" "time" @@ -9,8 +10,7 @@ import ( ) func Example() { - we := new(waiterr.WaitErr) - + we := waiterr.New() we.Go(func() error { time.Sleep(100 * time.Millisecond) fmt.Println("Goroutine 1 finished") @@ -42,8 +42,8 @@ func Example() { // something went wrong in goroutine 2 } -func ExampleWaitErr_WaitForError() { - we := new(waiterr.WaitErr) +func Example_waitForError() { + we := waiterr.New() we.Go(func() error { time.Sleep(100 * time.Millisecond) return errors.New("first error from we") @@ -62,8 +62,8 @@ func ExampleWaitErr_WaitForError() { } -func ExampleWaitErr_Unwrap() { - we := new(waiterr.WaitErr) +func Example_unwrap() { + we := waiterr.New() we.Go(func() error { time.Sleep(100 * time.Millisecond) return errors.New("first error from we") @@ -88,3 +88,50 @@ func ExampleWaitErr_Unwrap() { // second error from we // first error from we } + +func ExampleWithContext() { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + we, ctx := waiterr.WithContext(ctx) + + we.Go(func() error { + select { + case <-time.After(100 * time.Millisecond): + fmt.Println("Goroutine 1 finished") + return nil + case <-ctx.Done(): + return ctx.Err() + } + }) + + we.Go(func() error { + select { + case <-time.After(50 * time.Millisecond): + fmt.Println("Goroutine 2 finished with an error") + return errors.New("something went wrong in goroutine 2") + case <-ctx.Done(): + return ctx.Err() + } + }) + + we.Go(func() error { + select { + case <-time.After(150 * time.Millisecond): + fmt.Println("Goroutine 3 finished") + return nil + case <-ctx.Done(): + return ctx.Err() + } + }) + + if err := we.Wait(); err != nil { + fmt.Printf("All goroutines finished. Combined error: %s\n", err) + } + + // Output: + // Goroutine 2 finished with an error + // All goroutines finished. Combined error: something went wrong in goroutine 2 + // context canceled + // context canceled +} diff --git a/waiterr.go b/waiterr.go index 0ba8821..66a64b9 100644 --- a/waiterr.go +++ b/waiterr.go @@ -1,27 +1,59 @@ package waiterr import ( + "context" "errors" "sync" ) +func New() WaitErr { + var we waitErr + we.errCh = make(chan error, 1) + we.cancel = func(error) {} + + return &we +} + +func WithContext(ctx context.Context) (WaitErr, context.Context) { + var we waitErr + we.errCh = make(chan error, 1) + cCtx, canc := context.WithCancelCause(ctx) + we.cancel = canc + + return &we, cCtx +} + +// WaitErr provides a way to run multiple goroutines and wait for their completion, +// collecting any errors they return. +type WaitErr interface { + // Go runs f in its own goroutine. When f returns, its error is stored, and returned + // with [WaitErr.Wait]. + Go(f func() error) + // WaitForError waits for the first error to be returned by one of our go routines, and immediately returns + // with that error. If all functions return successfully, a nil is returned. + WaitForError() error + // Wait for all current goroutines to finish. Return an error that combines all errors returned + // in the group so far (if any). + Wait() error + // Unwrap returns all non-nil errors returned by our functions. + // If no errors were returned, or all errors are nil, it returns nil. + Unwrap() []error +} + // WaitErr wraps a [sync.WaitGroup] with error handling. -type WaitErr struct { - wg sync.WaitGroup - errs []error - mut sync.RWMutex - firstErr error - firstErrOnce sync.Once - errCh chan error // Buffered channel of size 1 - initErrChOnce sync.Once +type waitErr struct { + wg sync.WaitGroup + errs []error + mut sync.RWMutex + firstErr error + firstErrOnce sync.Once + errCh chan error // Buffered channel of size 1 + cancel context.CancelCauseFunc } // Go runs f in its own goroutine. When f returns, its error is stored, and returned // with [WaitErr.Wait]. -func (we *WaitErr) Go(f func() error) { - we.initErrChOnce.Do(func() { - we.errCh = make(chan error, 1) - }) +func (we *waitErr) Go(f func() error) { wrap := func() { err := f() @@ -29,6 +61,7 @@ func (we *WaitErr) Go(f func() error) { we.firstErrOnce.Do(func() { we.mut.Lock() // Acquire lock before writing to firstErr we.firstErr = err + we.cancel(err) we.mut.Unlock() // Release lock after writing // Non-blocking send to errCh @@ -48,10 +81,7 @@ func (we *WaitErr) Go(f func() error) { // WaitForError waits for the first error to be returned by one of our go routines, and immediately returns // with that error. If all functions return successfully, a nil is returned. It will panic if called before Go. -func (we *WaitErr) WaitForError() error { - if we.errCh == nil { - panic("WaitForError called before Go") - } +func (we *waitErr) WaitForError() error { // Check if an error has already been set we.mut.RLock() if we.firstErr != nil { @@ -82,16 +112,18 @@ func (we *WaitErr) WaitForError() error { // Wait for all current goroutines to finish. Return an error that combines all errors returned // in the group so far (if any). -func (we *WaitErr) Wait() error { +func (we *waitErr) Wait() error { we.wg.Wait() we.mut.RLock() defer we.mut.RUnlock() - return errors.Join(we.errs...) + ret := errors.Join(we.errs...) + we.cancel(ret) + return ret } // Unwrap returns all non-nil errors returned by our functions. // If no errors were returned, or all errors are nil, it returns nil. -func (we *WaitErr) Unwrap() []error { +func (we *waitErr) Unwrap() []error { errs := make([]error, 0, len(we.errs)) for _, e := range we.errs { if e != nil { diff --git a/waiterr_test.go b/waiterr_test.go index f2ba7b2..fb4818d 100644 --- a/waiterr_test.go +++ b/waiterr_test.go @@ -1,6 +1,7 @@ package waiterr_test import ( + "context" "errors" "testing" "testing/synctest" @@ -11,7 +12,7 @@ import ( ) func TestGo(t *testing.T) { - we := new(waiterr.WaitErr) + we := waiterr.New() err := errors.New("uh-oh") var run bool we.Go(func() error { @@ -24,7 +25,7 @@ func TestGo(t *testing.T) { } func TestWait(t *testing.T) { - we := new(waiterr.WaitErr) + we := waiterr.New() er1 := errors.New("uh-oh") er2 := errors.New("oops") we.Go(func() error { return er1 }) @@ -45,7 +46,7 @@ func TestWait(t *testing.T) { func TestWaitForError(tt *testing.T) { tt.Run("first error", func(t *testing.T) { - we := new(waiterr.WaitErr) + we := waiterr.New() er1 := errors.New("uh-oh") er2 := errors.New("oops") we.Go(func() error { return nil }) @@ -58,7 +59,7 @@ func TestWaitForError(tt *testing.T) { }) tt.Run("no error", func(t *testing.T) { - we := new(waiterr.WaitErr) + we := waiterr.New() we.Go(func() error { return nil }) we.Go(func() error { return nil }) we.Go(func() error { return nil }) @@ -67,19 +68,8 @@ func TestWaitForError(tt *testing.T) { be.Err(t, err, nil) }) - tt.Run("panic", func(t *testing.T) { - defer func() { - if r := recover(); r == nil { - t.Errorf("The code did not panic") - } - }() - - we := new(waiterr.WaitErr) - _ = we.WaitForError() - }) - tt.Run("first error set", func(tt2 *testing.T) { - we := new(waiterr.WaitErr) + we := waiterr.New() expectedErr := errors.New("pre-set error") synctest.Test(tt2, func(t *testing.T) { @@ -100,7 +90,7 @@ func TestWaitForError(tt *testing.T) { func TestUnwrap(tt *testing.T) { tt.Run("two errors", func(t *testing.T) { - we := new(waiterr.WaitErr) + we := waiterr.New() er1 := errors.New("error one") er2 := errors.New("error two") @@ -118,10 +108,36 @@ func TestUnwrap(tt *testing.T) { }) tt.Run("no errors", func(t *testing.T) { - weNoErr := new(waiterr.WaitErr) + weNoErr := waiterr.New() weNoErr.Go(func() error { return nil }) weNoErr.Go(func() error { return nil }) _ = weNoErr.Wait() be.Equal(t, weNoErr.Unwrap(), nil) }) } + +func TestWithContext(tt *testing.T) { + tt.Run("with error", func(tt2 *testing.T) { + er1 := errors.New("uh-oh") + er2 := errors.New("oops") + synctest.Test(tt2, func(t *testing.T) { + we, ctx := waiterr.WithContext(t.Context()) + we.Go(func() error { return er1 }) + synctest.Wait() // Ensure it finishes first + we.Go(func() error { return er2 }) + + er := context.Cause(ctx) + be.Err(t, er, er1) + be.True(t, !errors.Is(er, er2)) + }) + }) + + tt.Run("no error", func(t *testing.T) { + we, ctx := waiterr.WithContext(t.Context()) + we.Go(func() error { return nil }) + we.Go(func() error { return nil }) + er := we.Wait() + be.Err(t, er, nil) + be.Err(t, context.Cause(ctx), context.Canceled) + }) +}