diff --git a/waiterr.go b/waiterr.go index 94d6cd6..66a64b9 100644 --- a/waiterr.go +++ b/waiterr.go @@ -1,6 +1,7 @@ package waiterr import ( + "context" "errors" "sync" ) @@ -8,10 +9,20 @@ import ( func New() WaitErr { var we waitErr we.errCh = make(chan error, 1) + we.cancel = func(error) {} return &we } +func WithContext(ctx context.Context) (WaitErr, context.Context) { + var we waitErr + we.errCh = make(chan error, 1) + cCtx, canc := context.WithCancelCause(ctx) + we.cancel = canc + + return &we, cCtx +} + // WaitErr provides a way to run multiple goroutines and wait for their completion, // collecting any errors they return. type WaitErr interface { @@ -37,6 +48,7 @@ type waitErr struct { firstErr error firstErrOnce sync.Once errCh chan error // Buffered channel of size 1 + cancel context.CancelCauseFunc } // Go runs f in its own goroutine. When f returns, its error is stored, and returned @@ -49,6 +61,7 @@ func (we *waitErr) Go(f func() error) { we.firstErrOnce.Do(func() { we.mut.Lock() // Acquire lock before writing to firstErr we.firstErr = err + we.cancel(err) we.mut.Unlock() // Release lock after writing // Non-blocking send to errCh @@ -103,7 +116,9 @@ func (we *waitErr) Wait() error { we.wg.Wait() we.mut.RLock() defer we.mut.RUnlock() - return errors.Join(we.errs...) + ret := errors.Join(we.errs...) + we.cancel(ret) + return ret } // Unwrap returns all non-nil errors returned by our functions. diff --git a/waiterr_test.go b/waiterr_test.go index c173675..fb4818d 100644 --- a/waiterr_test.go +++ b/waiterr_test.go @@ -1,6 +1,7 @@ package waiterr_test import ( + "context" "errors" "testing" "testing/synctest" @@ -114,3 +115,29 @@ func TestUnwrap(tt *testing.T) { be.Equal(t, weNoErr.Unwrap(), nil) }) } + +func TestWithContext(tt *testing.T) { + tt.Run("with error", func(tt2 *testing.T) { + er1 := errors.New("uh-oh") + er2 := errors.New("oops") + synctest.Test(tt2, func(t *testing.T) { + we, ctx := waiterr.WithContext(t.Context()) + we.Go(func() error { return er1 }) + synctest.Wait() // Ensure it finishes first + we.Go(func() error { return er2 }) + + er := context.Cause(ctx) + be.Err(t, er, er1) + be.True(t, !errors.Is(er, er2)) + }) + }) + + tt.Run("no error", func(t *testing.T) { + we, ctx := waiterr.WithContext(t.Context()) + we.Go(func() error { return nil }) + we.Go(func() error { return nil }) + er := we.Wait() + be.Err(t, er, nil) + be.Err(t, context.Cause(ctx), context.Canceled) + }) +}