✨ Refactor: Convert WaitErr to an interface and add New constructor
This commit is contained in:
parent
a68fc26481
commit
edc34062e9
4 changed files with 47 additions and 42 deletions
|
|
@ -33,7 +33,7 @@ import (
|
||||||
)
|
)
|
||||||
|
|
||||||
func main() {
|
func main() {
|
||||||
we := new(waiterr.WaitErr)
|
we := waiterr.New()
|
||||||
|
|
||||||
we.Go(func() error {
|
we.Go(func() error {
|
||||||
time.Sleep(100 * time.Millisecond)
|
time.Sleep(100 * time.Millisecond)
|
||||||
|
|
@ -60,7 +60,7 @@ func main() {
|
||||||
}
|
}
|
||||||
|
|
||||||
// You can also get the first error immediately
|
// You can also get the first error immediately
|
||||||
we2 := new(waiterr.WaitErr)
|
we2 := waiterr.New()
|
||||||
we2.Go(func() error {
|
we2.Go(func() error {
|
||||||
time.Sleep(100 * time.Millisecond)
|
time.Sleep(100 * time.Millisecond)
|
||||||
return errors.New("first error from we2")
|
return errors.New("first error from we2")
|
||||||
|
|
|
||||||
|
|
@ -9,8 +9,7 @@ import (
|
||||||
)
|
)
|
||||||
|
|
||||||
func Example() {
|
func Example() {
|
||||||
we := new(waiterr.WaitErr)
|
we := waiterr.New()
|
||||||
|
|
||||||
we.Go(func() error {
|
we.Go(func() error {
|
||||||
time.Sleep(100 * time.Millisecond)
|
time.Sleep(100 * time.Millisecond)
|
||||||
fmt.Println("Goroutine 1 finished")
|
fmt.Println("Goroutine 1 finished")
|
||||||
|
|
@ -43,7 +42,7 @@ func Example() {
|
||||||
}
|
}
|
||||||
|
|
||||||
func ExampleWaitErr_WaitForError() {
|
func ExampleWaitErr_WaitForError() {
|
||||||
we := new(waiterr.WaitErr)
|
we := waiterr.New()
|
||||||
we.Go(func() error {
|
we.Go(func() error {
|
||||||
time.Sleep(100 * time.Millisecond)
|
time.Sleep(100 * time.Millisecond)
|
||||||
return errors.New("first error from we")
|
return errors.New("first error from we")
|
||||||
|
|
@ -63,7 +62,7 @@ func ExampleWaitErr_WaitForError() {
|
||||||
}
|
}
|
||||||
|
|
||||||
func ExampleWaitErr_Unwrap() {
|
func ExampleWaitErr_Unwrap() {
|
||||||
we := new(waiterr.WaitErr)
|
we := waiterr.New()
|
||||||
we.Go(func() error {
|
we.Go(func() error {
|
||||||
time.Sleep(100 * time.Millisecond)
|
time.Sleep(100 * time.Millisecond)
|
||||||
return errors.New("first error from we")
|
return errors.New("first error from we")
|
||||||
|
|
|
||||||
53
waiterr.go
53
waiterr.go
|
|
@ -5,23 +5,43 @@ import (
|
||||||
"sync"
|
"sync"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
func New() WaitErr {
|
||||||
|
var we waitErr
|
||||||
|
we.errCh = make(chan error, 1)
|
||||||
|
|
||||||
|
return &we
|
||||||
|
}
|
||||||
|
|
||||||
|
// WaitErr provides a way to run multiple goroutines and wait for their completion,
|
||||||
|
// collecting any errors they return.
|
||||||
|
type WaitErr interface {
|
||||||
|
// Go runs f in its own goroutine. When f returns, its error is stored, and returned
|
||||||
|
// with [WaitErr.Wait].
|
||||||
|
Go(f func() error)
|
||||||
|
// WaitForError waits for the first error to be returned by one of our go routines, and immediately returns
|
||||||
|
// with that error. If all functions return successfully, a nil is returned.
|
||||||
|
WaitForError() error
|
||||||
|
// Wait for all current goroutines to finish. Return an error that combines all errors returned
|
||||||
|
// in the group so far (if any).
|
||||||
|
Wait() error
|
||||||
|
// Unwrap returns all non-nil errors returned by our functions.
|
||||||
|
// If no errors were returned, or all errors are nil, it returns nil.
|
||||||
|
Unwrap() []error
|
||||||
|
}
|
||||||
|
|
||||||
// WaitErr wraps a [sync.WaitGroup] with error handling.
|
// WaitErr wraps a [sync.WaitGroup] with error handling.
|
||||||
type WaitErr struct {
|
type waitErr struct {
|
||||||
wg sync.WaitGroup
|
wg sync.WaitGroup
|
||||||
errs []error
|
errs []error
|
||||||
mut sync.RWMutex
|
mut sync.RWMutex
|
||||||
firstErr error
|
firstErr error
|
||||||
firstErrOnce sync.Once
|
firstErrOnce sync.Once
|
||||||
errCh chan error // Buffered channel of size 1
|
errCh chan error // Buffered channel of size 1
|
||||||
initErrChOnce sync.Once
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Go runs f in its own goroutine. When f returns, its error is stored, and returned
|
// Go runs f in its own goroutine. When f returns, its error is stored, and returned
|
||||||
// with [WaitErr.Wait].
|
// with [WaitErr.Wait].
|
||||||
func (we *WaitErr) Go(f func() error) {
|
func (we *waitErr) Go(f func() error) {
|
||||||
we.initErrChOnce.Do(func() {
|
|
||||||
we.errCh = make(chan error, 1)
|
|
||||||
})
|
|
||||||
wrap := func() {
|
wrap := func() {
|
||||||
err := f()
|
err := f()
|
||||||
|
|
||||||
|
|
@ -48,10 +68,7 @@ func (we *WaitErr) Go(f func() error) {
|
||||||
|
|
||||||
// WaitForError waits for the first error to be returned by one of our go routines, and immediately returns
|
// WaitForError waits for the first error to be returned by one of our go routines, and immediately returns
|
||||||
// with that error. If all functions return successfully, a nil is returned. It will panic if called before Go.
|
// with that error. If all functions return successfully, a nil is returned. It will panic if called before Go.
|
||||||
func (we *WaitErr) WaitForError() error {
|
func (we *waitErr) WaitForError() error {
|
||||||
if we.errCh == nil {
|
|
||||||
panic("WaitForError called before Go")
|
|
||||||
}
|
|
||||||
// Check if an error has already been set
|
// Check if an error has already been set
|
||||||
we.mut.RLock()
|
we.mut.RLock()
|
||||||
if we.firstErr != nil {
|
if we.firstErr != nil {
|
||||||
|
|
@ -82,7 +99,7 @@ func (we *WaitErr) WaitForError() error {
|
||||||
|
|
||||||
// Wait for all current goroutines to finish. Return an error that combines all errors returned
|
// Wait for all current goroutines to finish. Return an error that combines all errors returned
|
||||||
// in the group so far (if any).
|
// in the group so far (if any).
|
||||||
func (we *WaitErr) Wait() error {
|
func (we *waitErr) Wait() error {
|
||||||
we.wg.Wait()
|
we.wg.Wait()
|
||||||
we.mut.RLock()
|
we.mut.RLock()
|
||||||
defer we.mut.RUnlock()
|
defer we.mut.RUnlock()
|
||||||
|
|
@ -91,7 +108,7 @@ func (we *WaitErr) Wait() error {
|
||||||
|
|
||||||
// Unwrap returns all non-nil errors returned by our functions.
|
// Unwrap returns all non-nil errors returned by our functions.
|
||||||
// If no errors were returned, or all errors are nil, it returns nil.
|
// If no errors were returned, or all errors are nil, it returns nil.
|
||||||
func (we *WaitErr) Unwrap() []error {
|
func (we *waitErr) Unwrap() []error {
|
||||||
errs := make([]error, 0, len(we.errs))
|
errs := make([]error, 0, len(we.errs))
|
||||||
for _, e := range we.errs {
|
for _, e := range we.errs {
|
||||||
if e != nil {
|
if e != nil {
|
||||||
|
|
|
||||||
|
|
@ -11,7 +11,7 @@ import (
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestGo(t *testing.T) {
|
func TestGo(t *testing.T) {
|
||||||
we := new(waiterr.WaitErr)
|
we := waiterr.New()
|
||||||
err := errors.New("uh-oh")
|
err := errors.New("uh-oh")
|
||||||
var run bool
|
var run bool
|
||||||
we.Go(func() error {
|
we.Go(func() error {
|
||||||
|
|
@ -24,7 +24,7 @@ func TestGo(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestWait(t *testing.T) {
|
func TestWait(t *testing.T) {
|
||||||
we := new(waiterr.WaitErr)
|
we := waiterr.New()
|
||||||
er1 := errors.New("uh-oh")
|
er1 := errors.New("uh-oh")
|
||||||
er2 := errors.New("oops")
|
er2 := errors.New("oops")
|
||||||
we.Go(func() error { return er1 })
|
we.Go(func() error { return er1 })
|
||||||
|
|
@ -45,7 +45,7 @@ func TestWait(t *testing.T) {
|
||||||
|
|
||||||
func TestWaitForError(tt *testing.T) {
|
func TestWaitForError(tt *testing.T) {
|
||||||
tt.Run("first error", func(t *testing.T) {
|
tt.Run("first error", func(t *testing.T) {
|
||||||
we := new(waiterr.WaitErr)
|
we := waiterr.New()
|
||||||
er1 := errors.New("uh-oh")
|
er1 := errors.New("uh-oh")
|
||||||
er2 := errors.New("oops")
|
er2 := errors.New("oops")
|
||||||
we.Go(func() error { return nil })
|
we.Go(func() error { return nil })
|
||||||
|
|
@ -58,7 +58,7 @@ func TestWaitForError(tt *testing.T) {
|
||||||
})
|
})
|
||||||
|
|
||||||
tt.Run("no error", func(t *testing.T) {
|
tt.Run("no error", func(t *testing.T) {
|
||||||
we := new(waiterr.WaitErr)
|
we := waiterr.New()
|
||||||
we.Go(func() error { return nil })
|
we.Go(func() error { return nil })
|
||||||
we.Go(func() error { return nil })
|
we.Go(func() error { return nil })
|
||||||
we.Go(func() error { return nil })
|
we.Go(func() error { return nil })
|
||||||
|
|
@ -67,19 +67,8 @@ func TestWaitForError(tt *testing.T) {
|
||||||
be.Err(t, err, nil)
|
be.Err(t, err, nil)
|
||||||
})
|
})
|
||||||
|
|
||||||
tt.Run("panic", func(t *testing.T) {
|
|
||||||
defer func() {
|
|
||||||
if r := recover(); r == nil {
|
|
||||||
t.Errorf("The code did not panic")
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
we := new(waiterr.WaitErr)
|
|
||||||
_ = we.WaitForError()
|
|
||||||
})
|
|
||||||
|
|
||||||
tt.Run("first error set", func(tt2 *testing.T) {
|
tt.Run("first error set", func(tt2 *testing.T) {
|
||||||
we := new(waiterr.WaitErr)
|
we := waiterr.New()
|
||||||
expectedErr := errors.New("pre-set error")
|
expectedErr := errors.New("pre-set error")
|
||||||
|
|
||||||
synctest.Test(tt2, func(t *testing.T) {
|
synctest.Test(tt2, func(t *testing.T) {
|
||||||
|
|
@ -100,7 +89,7 @@ func TestWaitForError(tt *testing.T) {
|
||||||
|
|
||||||
func TestUnwrap(tt *testing.T) {
|
func TestUnwrap(tt *testing.T) {
|
||||||
tt.Run("two errors", func(t *testing.T) {
|
tt.Run("two errors", func(t *testing.T) {
|
||||||
we := new(waiterr.WaitErr)
|
we := waiterr.New()
|
||||||
er1 := errors.New("error one")
|
er1 := errors.New("error one")
|
||||||
er2 := errors.New("error two")
|
er2 := errors.New("error two")
|
||||||
|
|
||||||
|
|
@ -118,7 +107,7 @@ func TestUnwrap(tt *testing.T) {
|
||||||
})
|
})
|
||||||
|
|
||||||
tt.Run("no errors", func(t *testing.T) {
|
tt.Run("no errors", func(t *testing.T) {
|
||||||
weNoErr := new(waiterr.WaitErr)
|
weNoErr := waiterr.New()
|
||||||
weNoErr.Go(func() error { return nil })
|
weNoErr.Go(func() error { return nil })
|
||||||
weNoErr.Go(func() error { return nil })
|
weNoErr.Go(func() error { return nil })
|
||||||
_ = weNoErr.Wait()
|
_ = weNoErr.Wait()
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue