Add WithContext function for cancellation

This commit is contained in:
Dan Jones 2025-11-14 14:53:36 -06:00
commit 09feb51213
2 changed files with 43 additions and 1 deletions

View file

@ -1,6 +1,7 @@
package waiterr
import (
"context"
"errors"
"sync"
)
@ -8,10 +9,20 @@ import (
func New() WaitErr {
var we waitErr
we.errCh = make(chan error, 1)
we.cancel = func(error) {}
return &we
}
func WithContext(ctx context.Context) (WaitErr, context.Context) {
var we waitErr
we.errCh = make(chan error, 1)
cCtx, canc := context.WithCancelCause(ctx)
we.cancel = canc
return &we, cCtx
}
// WaitErr provides a way to run multiple goroutines and wait for their completion,
// collecting any errors they return.
type WaitErr interface {
@ -37,6 +48,7 @@ type waitErr struct {
firstErr error
firstErrOnce sync.Once
errCh chan error // Buffered channel of size 1
cancel context.CancelCauseFunc
}
// Go runs f in its own goroutine. When f returns, its error is stored, and returned
@ -49,6 +61,7 @@ func (we *waitErr) Go(f func() error) {
we.firstErrOnce.Do(func() {
we.mut.Lock() // Acquire lock before writing to firstErr
we.firstErr = err
we.cancel(err)
we.mut.Unlock() // Release lock after writing
// Non-blocking send to errCh
@ -103,7 +116,9 @@ func (we *waitErr) Wait() error {
we.wg.Wait()
we.mut.RLock()
defer we.mut.RUnlock()
return errors.Join(we.errs...)
ret := errors.Join(we.errs...)
we.cancel(ret)
return ret
}
// Unwrap returns all non-nil errors returned by our functions.

View file

@ -1,6 +1,7 @@
package waiterr_test
import (
"context"
"errors"
"testing"
"testing/synctest"
@ -114,3 +115,29 @@ func TestUnwrap(tt *testing.T) {
be.Equal(t, weNoErr.Unwrap(), nil)
})
}
func TestWithContext(tt *testing.T) {
tt.Run("with error", func(tt2 *testing.T) {
er1 := errors.New("uh-oh")
er2 := errors.New("oops")
synctest.Test(tt2, func(t *testing.T) {
we, ctx := waiterr.WithContext(t.Context())
we.Go(func() error { return er1 })
synctest.Wait() // Ensure it finishes first
we.Go(func() error { return er2 })
er := context.Cause(ctx)
be.Err(t, er, er1)
be.True(t, !errors.Is(er, er2))
})
})
tt.Run("no error", func(t *testing.T) {
we, ctx := waiterr.WithContext(t.Context())
we.Go(func() error { return nil })
we.Go(func() error { return nil })
er := we.Wait()
be.Err(t, er, nil)
be.Err(t, context.Cause(ctx), context.Canceled)
})
}