From edc34062e90f3dbd44fc41ecb4d22802aaa3c687 Mon Sep 17 00:00:00 2001 From: Dan Jones Date: Fri, 14 Nov 2025 14:04:51 -0600 Subject: [PATCH] =?UTF-8?q?=E2=9C=A8=20Refactor:=20Convert=20WaitErr=20to?= =?UTF-8?q?=20an=20interface=20and=20add=20New=20constructor?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- README.md | 4 ++-- example_test.go | 7 +++---- waiterr.go | 53 ++++++++++++++++++++++++++++++++----------------- waiterr_test.go | 25 +++++++---------------- 4 files changed, 47 insertions(+), 42 deletions(-) diff --git a/README.md b/README.md index b503e38..e7923a0 100644 --- a/README.md +++ b/README.md @@ -33,7 +33,7 @@ import ( ) func main() { - we := new(waiterr.WaitErr) + we := waiterr.New() we.Go(func() error { time.Sleep(100 * time.Millisecond) @@ -60,7 +60,7 @@ func main() { } // You can also get the first error immediately - we2 := new(waiterr.WaitErr) + we2 := waiterr.New() we2.Go(func() error { time.Sleep(100 * time.Millisecond) return errors.New("first error from we2") diff --git a/example_test.go b/example_test.go index 3246ddd..2f1c9cd 100644 --- a/example_test.go +++ b/example_test.go @@ -9,8 +9,7 @@ import ( ) func Example() { - we := new(waiterr.WaitErr) - + we := waiterr.New() we.Go(func() error { time.Sleep(100 * time.Millisecond) fmt.Println("Goroutine 1 finished") @@ -43,7 +42,7 @@ func Example() { } func ExampleWaitErr_WaitForError() { - we := new(waiterr.WaitErr) + we := waiterr.New() we.Go(func() error { time.Sleep(100 * time.Millisecond) return errors.New("first error from we") @@ -63,7 +62,7 @@ func ExampleWaitErr_WaitForError() { } func ExampleWaitErr_Unwrap() { - we := new(waiterr.WaitErr) + we := waiterr.New() we.Go(func() error { time.Sleep(100 * time.Millisecond) return errors.New("first error from we") diff --git a/waiterr.go b/waiterr.go index 0ba8821..94d6cd6 100644 --- a/waiterr.go +++ b/waiterr.go @@ -5,23 +5,43 @@ import ( "sync" ) +func New() WaitErr { + var we waitErr + we.errCh = make(chan error, 1) + + return &we +} + +// WaitErr provides a way to run multiple goroutines and wait for their completion, +// collecting any errors they return. +type WaitErr interface { + // Go runs f in its own goroutine. When f returns, its error is stored, and returned + // with [WaitErr.Wait]. + Go(f func() error) + // WaitForError waits for the first error to be returned by one of our go routines, and immediately returns + // with that error. If all functions return successfully, a nil is returned. + WaitForError() error + // Wait for all current goroutines to finish. Return an error that combines all errors returned + // in the group so far (if any). + Wait() error + // Unwrap returns all non-nil errors returned by our functions. + // If no errors were returned, or all errors are nil, it returns nil. + Unwrap() []error +} + // WaitErr wraps a [sync.WaitGroup] with error handling. -type WaitErr struct { - wg sync.WaitGroup - errs []error - mut sync.RWMutex - firstErr error - firstErrOnce sync.Once - errCh chan error // Buffered channel of size 1 - initErrChOnce sync.Once +type waitErr struct { + wg sync.WaitGroup + errs []error + mut sync.RWMutex + firstErr error + firstErrOnce sync.Once + errCh chan error // Buffered channel of size 1 } // Go runs f in its own goroutine. When f returns, its error is stored, and returned // with [WaitErr.Wait]. -func (we *WaitErr) Go(f func() error) { - we.initErrChOnce.Do(func() { - we.errCh = make(chan error, 1) - }) +func (we *waitErr) Go(f func() error) { wrap := func() { err := f() @@ -48,10 +68,7 @@ func (we *WaitErr) Go(f func() error) { // WaitForError waits for the first error to be returned by one of our go routines, and immediately returns // with that error. If all functions return successfully, a nil is returned. It will panic if called before Go. -func (we *WaitErr) WaitForError() error { - if we.errCh == nil { - panic("WaitForError called before Go") - } +func (we *waitErr) WaitForError() error { // Check if an error has already been set we.mut.RLock() if we.firstErr != nil { @@ -82,7 +99,7 @@ func (we *WaitErr) WaitForError() error { // Wait for all current goroutines to finish. Return an error that combines all errors returned // in the group so far (if any). -func (we *WaitErr) Wait() error { +func (we *waitErr) Wait() error { we.wg.Wait() we.mut.RLock() defer we.mut.RUnlock() @@ -91,7 +108,7 @@ func (we *WaitErr) Wait() error { // Unwrap returns all non-nil errors returned by our functions. // If no errors were returned, or all errors are nil, it returns nil. -func (we *WaitErr) Unwrap() []error { +func (we *waitErr) Unwrap() []error { errs := make([]error, 0, len(we.errs)) for _, e := range we.errs { if e != nil { diff --git a/waiterr_test.go b/waiterr_test.go index f2ba7b2..c173675 100644 --- a/waiterr_test.go +++ b/waiterr_test.go @@ -11,7 +11,7 @@ import ( ) func TestGo(t *testing.T) { - we := new(waiterr.WaitErr) + we := waiterr.New() err := errors.New("uh-oh") var run bool we.Go(func() error { @@ -24,7 +24,7 @@ func TestGo(t *testing.T) { } func TestWait(t *testing.T) { - we := new(waiterr.WaitErr) + we := waiterr.New() er1 := errors.New("uh-oh") er2 := errors.New("oops") we.Go(func() error { return er1 }) @@ -45,7 +45,7 @@ func TestWait(t *testing.T) { func TestWaitForError(tt *testing.T) { tt.Run("first error", func(t *testing.T) { - we := new(waiterr.WaitErr) + we := waiterr.New() er1 := errors.New("uh-oh") er2 := errors.New("oops") we.Go(func() error { return nil }) @@ -58,7 +58,7 @@ func TestWaitForError(tt *testing.T) { }) tt.Run("no error", func(t *testing.T) { - we := new(waiterr.WaitErr) + we := waiterr.New() we.Go(func() error { return nil }) we.Go(func() error { return nil }) we.Go(func() error { return nil }) @@ -67,19 +67,8 @@ func TestWaitForError(tt *testing.T) { be.Err(t, err, nil) }) - tt.Run("panic", func(t *testing.T) { - defer func() { - if r := recover(); r == nil { - t.Errorf("The code did not panic") - } - }() - - we := new(waiterr.WaitErr) - _ = we.WaitForError() - }) - tt.Run("first error set", func(tt2 *testing.T) { - we := new(waiterr.WaitErr) + we := waiterr.New() expectedErr := errors.New("pre-set error") synctest.Test(tt2, func(t *testing.T) { @@ -100,7 +89,7 @@ func TestWaitForError(tt *testing.T) { func TestUnwrap(tt *testing.T) { tt.Run("two errors", func(t *testing.T) { - we := new(waiterr.WaitErr) + we := waiterr.New() er1 := errors.New("error one") er2 := errors.New("error two") @@ -118,7 +107,7 @@ func TestUnwrap(tt *testing.T) { }) tt.Run("no errors", func(t *testing.T) { - weNoErr := new(waiterr.WaitErr) + weNoErr := waiterr.New() weNoErr.Go(func() error { return nil }) weNoErr.Go(func() error { return nil }) _ = weNoErr.Wait()