✨ Add WithContext function for cancellation
This commit is contained in:
parent
edc34062e9
commit
09feb51213
2 changed files with 43 additions and 1 deletions
17
waiterr.go
17
waiterr.go
|
|
@ -1,6 +1,7 @@
|
||||||
package waiterr
|
package waiterr
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"errors"
|
"errors"
|
||||||
"sync"
|
"sync"
|
||||||
)
|
)
|
||||||
|
|
@ -8,10 +9,20 @@ import (
|
||||||
func New() WaitErr {
|
func New() WaitErr {
|
||||||
var we waitErr
|
var we waitErr
|
||||||
we.errCh = make(chan error, 1)
|
we.errCh = make(chan error, 1)
|
||||||
|
we.cancel = func(error) {}
|
||||||
|
|
||||||
return &we
|
return &we
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func WithContext(ctx context.Context) (WaitErr, context.Context) {
|
||||||
|
var we waitErr
|
||||||
|
we.errCh = make(chan error, 1)
|
||||||
|
cCtx, canc := context.WithCancelCause(ctx)
|
||||||
|
we.cancel = canc
|
||||||
|
|
||||||
|
return &we, cCtx
|
||||||
|
}
|
||||||
|
|
||||||
// WaitErr provides a way to run multiple goroutines and wait for their completion,
|
// WaitErr provides a way to run multiple goroutines and wait for their completion,
|
||||||
// collecting any errors they return.
|
// collecting any errors they return.
|
||||||
type WaitErr interface {
|
type WaitErr interface {
|
||||||
|
|
@ -37,6 +48,7 @@ type waitErr struct {
|
||||||
firstErr error
|
firstErr error
|
||||||
firstErrOnce sync.Once
|
firstErrOnce sync.Once
|
||||||
errCh chan error // Buffered channel of size 1
|
errCh chan error // Buffered channel of size 1
|
||||||
|
cancel context.CancelCauseFunc
|
||||||
}
|
}
|
||||||
|
|
||||||
// Go runs f in its own goroutine. When f returns, its error is stored, and returned
|
// Go runs f in its own goroutine. When f returns, its error is stored, and returned
|
||||||
|
|
@ -49,6 +61,7 @@ func (we *waitErr) Go(f func() error) {
|
||||||
we.firstErrOnce.Do(func() {
|
we.firstErrOnce.Do(func() {
|
||||||
we.mut.Lock() // Acquire lock before writing to firstErr
|
we.mut.Lock() // Acquire lock before writing to firstErr
|
||||||
we.firstErr = err
|
we.firstErr = err
|
||||||
|
we.cancel(err)
|
||||||
we.mut.Unlock() // Release lock after writing
|
we.mut.Unlock() // Release lock after writing
|
||||||
|
|
||||||
// Non-blocking send to errCh
|
// Non-blocking send to errCh
|
||||||
|
|
@ -103,7 +116,9 @@ func (we *waitErr) Wait() error {
|
||||||
we.wg.Wait()
|
we.wg.Wait()
|
||||||
we.mut.RLock()
|
we.mut.RLock()
|
||||||
defer we.mut.RUnlock()
|
defer we.mut.RUnlock()
|
||||||
return errors.Join(we.errs...)
|
ret := errors.Join(we.errs...)
|
||||||
|
we.cancel(ret)
|
||||||
|
return ret
|
||||||
}
|
}
|
||||||
|
|
||||||
// Unwrap returns all non-nil errors returned by our functions.
|
// Unwrap returns all non-nil errors returned by our functions.
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,7 @@
|
||||||
package waiterr_test
|
package waiterr_test
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"errors"
|
"errors"
|
||||||
"testing"
|
"testing"
|
||||||
"testing/synctest"
|
"testing/synctest"
|
||||||
|
|
@ -114,3 +115,29 @@ func TestUnwrap(tt *testing.T) {
|
||||||
be.Equal(t, weNoErr.Unwrap(), nil)
|
be.Equal(t, weNoErr.Unwrap(), nil)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestWithContext(tt *testing.T) {
|
||||||
|
tt.Run("with error", func(tt2 *testing.T) {
|
||||||
|
er1 := errors.New("uh-oh")
|
||||||
|
er2 := errors.New("oops")
|
||||||
|
synctest.Test(tt2, func(t *testing.T) {
|
||||||
|
we, ctx := waiterr.WithContext(t.Context())
|
||||||
|
we.Go(func() error { return er1 })
|
||||||
|
synctest.Wait() // Ensure it finishes first
|
||||||
|
we.Go(func() error { return er2 })
|
||||||
|
|
||||||
|
er := context.Cause(ctx)
|
||||||
|
be.Err(t, er, er1)
|
||||||
|
be.True(t, !errors.Is(er, er2))
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
tt.Run("no error", func(t *testing.T) {
|
||||||
|
we, ctx := waiterr.WithContext(t.Context())
|
||||||
|
we.Go(func() error { return nil })
|
||||||
|
we.Go(func() error { return nil })
|
||||||
|
er := we.Wait()
|
||||||
|
be.Err(t, er, nil)
|
||||||
|
be.Err(t, context.Cause(ctx), context.Canceled)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue