waiterr/waiterr.go

122 lines
3.1 KiB
Go

package waiterr
import (
"errors"
"sync"
)
// New returns an empty WaitErr that is ready to use.
func New() WaitErr {
	return &waitErr{
		// A buffer of one lets the first failing goroutine hand over its error
		// without blocking, even when nobody is receiving yet.
		errCh: make(chan error, 1),
	}
}
// WaitErr provides a way to run multiple functions in their own goroutines,
// wait for their completion, and collect any errors they return.
// Instances are created with [New].
type WaitErr interface {
// Go runs f in its own goroutine. When f returns, its error is stored and
// later reported by [WaitErr.Wait] and [WaitErr.Unwrap].
Go(f func() error)
// WaitForError blocks until one of the functions started with Go returns a
// non-nil error, and returns that error without waiting for the remaining
// functions. If all functions return successfully, it returns nil.
WaitForError() error
// Wait blocks until all goroutines started so far have finished. It returns an
// error that combines all errors returned in the group so far, or nil if there
// were none.
Wait() error
// Unwrap returns all non-nil errors returned by our functions.
// If no errors were returned, or all errors are nil, it returns nil.
Unwrap() []error
}
// waitErr is the [WaitErr] implementation returned by [New]. It wraps a
// [sync.WaitGroup] with error handling.
type waitErr struct {
// wg tracks the goroutines started by Go.
wg sync.WaitGroup
// errs holds the errors collected from finished functions. Guarded by mut.
errs []error
// mut guards errs and firstErr.
mut sync.RWMutex
// firstErr is the first non-nil error reported by a function. Guarded by mut.
firstErr error
// firstErrOnce ensures firstErr is set, and errCh is fed, only once.
firstErrOnce sync.Once
// errCh carries the first non-nil error to WaitForError. New creates it with
// a buffer of 1 so the sending goroutine does not have to wait for a receiver.
errCh chan error
}
// Go runs f in its own goroutine. When f returns a non-nil error, that error is
// recorded and later reported by Wait and Unwrap; the first such error is also
// handed to WaitForError.
//
// Successful (nil) results are not stored: Wait and Unwrap only ever report
// non-nil errors, so keeping nil entries would merely grow the slice by one
// slot per successful call.
func (we *waitErr) Go(f func() error) {
	we.wg.Go(func() {
		err := f()
		if err == nil {
			return
		}

		// Only the first failure is published as firstErr / on errCh.
		we.firstErrOnce.Do(func() {
			we.mut.Lock()
			we.firstErr = err
			we.mut.Unlock()

			// Non-blocking send: errCh has a buffer of one and this closure runs
			// at most once, so the default case is only a safety net.
			select {
			case we.errCh <- err:
			default:
			}
		})

		we.mut.Lock()
		we.errs = append(we.errs, err)
		we.mut.Unlock()
	})
}
// WaitForError blocks until one of the functions started with Go returns a
// non-nil error and returns that error right away, without waiting for the
// remaining functions. If all functions finish successfully, it returns nil.
//
// It does not panic when called before Go: with no function outstanding it
// returns whatever error has been recorded so far (nil if none). Once an error
// has been recorded, every later call returns that same error.
//
// Each call that has to block starts a helper goroutine which lives until all
// functions started with Go have finished.
func (we *waitErr) WaitForError() error {
	// Fast path: an error has already been recorded.
	we.mut.RLock()
	recorded := we.firstErr
	we.mut.RUnlock()
	if recorded != nil {
		return recorded
	}

	// done is closed once every goroutine in the group has finished.
	done := make(chan struct{})
	go func() {
		we.wg.Wait()
		close(done)
	}()

	select {
	case err := <-we.errCh:
		// The first error is sent on errCh only once. Put it back so that any
		// other caller blocked in this select is woken too, instead of having
		// to wait for done. The buffer slot was just freed by our receive, so
		// the non-blocking send is expected to succeed.
		select {
		case we.errCh <- err:
		default:
		}
		return err
	case <-done:
		// All goroutines finished. Re-check firstErr in case it was set just
		// before done was closed; it is nil if no error occurred.
		we.mut.RLock()
		defer we.mut.RUnlock()
		return we.firstErr
	}
}
// Wait blocks until every goroutine started so far has finished, then returns
// the errors collected in the group so far combined with [errors.Join]. The
// result is nil when no non-nil error was collected.
func (we *waitErr) Wait() error {
	we.wg.Wait()

	// Hold the read lock only while the collected errors are being joined.
	we.mut.RLock()
	joined := errors.Join(we.errs...)
	we.mut.RUnlock()

	return joined
}
// Unwrap returns all non-nil errors returned by our functions so far, in a
// newly allocated slice. It does not wait for running functions to finish.
// If no errors were returned, or all errors are nil, it returns nil.
func (we *waitErr) Unwrap() []error {
	// errs is appended to by the goroutines started with Go while they hold
	// mut, so it must be read under the lock as well.
	we.mut.RLock()
	defer we.mut.RUnlock()

	errs := make([]error, 0, len(we.errs))
	for _, e := range we.errs {
		if e != nil {
			errs = append(errs, e)
		}
	}
	if len(errs) == 0 {
		return nil
	}
	return errs
}