122 lines
3.1 KiB
Go
122 lines
3.1 KiB
Go
package waiterr
|
|
|
|
import (
|
|
"errors"
|
|
"sync"
|
|
)
|
|
|
|
func New() WaitErr {
|
|
var we waitErr
|
|
we.errCh = make(chan error, 1)
|
|
|
|
return &we
|
|
}
|
|
|
|
// WaitErr runs functions concurrently, each in its own goroutine, and gathers
// the errors they return so that callers can wait on the group as a whole.
type WaitErr interface {
	// Go starts f in a new goroutine. The error f returns is recorded and
	// later reported by [WaitErr.Wait].
	Go(f func() error)

	// WaitForError blocks until one of the started goroutines returns an
	// error and hands back that first error right away, without waiting for
	// the remaining goroutines. It returns nil when every function succeeds.
	WaitForError() error

	// Wait blocks until all goroutines started so far have finished. It
	// returns a single error combining every error recorded in the group up
	// to that point, or nil when there are none.
	Wait() error

	// Unwrap reports every non-nil error returned by the started functions.
	// The result is nil when no function has returned a non-nil error.
	Unwrap() []error
}
|
|
|
|
// waitErr is the WaitErr implementation returned by New: a [sync.WaitGroup]
// combined with bookkeeping for the errors returned by its goroutines.
type waitErr struct {
	// wg tracks the goroutines started through Go.
	wg sync.WaitGroup
	// errs collects the errors recorded by goroutines started through Go.
	errs []error
	// mut is held by Go while it writes errs and firstErr.
	mut sync.RWMutex
	// firstErr is the first non-nil error reported; it is assigned at most once.
	firstErr error
	// firstErrOnce makes sure firstErr and errCh are populated a single time.
	firstErrOnce sync.Once
	// errCh delivers the first error to WaitForError. New allocates it with
	// a buffer of one element.
	errCh chan error
}
|
|
|
|
// Go runs f in its own goroutine. When f returns, its error is stored, and returned
|
|
// with [WaitErr.Wait].
|
|
func (we *waitErr) Go(f func() error) {
|
|
wrap := func() {
|
|
err := f()
|
|
|
|
if err != nil {
|
|
we.firstErrOnce.Do(func() {
|
|
we.mut.Lock() // Acquire lock before writing to firstErr
|
|
we.firstErr = err
|
|
we.mut.Unlock() // Release lock after writing
|
|
|
|
// Non-blocking send to errCh
|
|
select {
|
|
case we.errCh <- err:
|
|
default:
|
|
}
|
|
})
|
|
}
|
|
|
|
we.mut.Lock()
|
|
defer we.mut.Unlock()
|
|
we.errs = append(we.errs, err)
|
|
}
|
|
we.wg.Go(wrap)
|
|
}
|
|
|
|
// WaitForError waits for the first error to be returned by one of our go routines, and immediately returns
|
|
// with that error. If all functions return successfully, a nil is returned. It will panic if called before Go.
|
|
func (we *waitErr) WaitForError() error {
|
|
// Check if an error has already been set
|
|
we.mut.RLock()
|
|
if we.firstErr != nil {
|
|
err := we.firstErr
|
|
we.mut.RUnlock()
|
|
return err
|
|
}
|
|
we.mut.RUnlock()
|
|
|
|
// Create a channel to signal when all goroutines are done
|
|
done := make(chan struct{})
|
|
go func() {
|
|
we.wg.Wait()
|
|
close(done)
|
|
}()
|
|
|
|
select {
|
|
case err := <-we.errCh:
|
|
return err
|
|
case <-done:
|
|
// All goroutines finished, and no error was sent to errCh
|
|
// Re-check firstErr in case it was set just before 'done' was closed
|
|
we.mut.RLock()
|
|
defer we.mut.RUnlock()
|
|
return we.firstErr // This will be nil if no error occurred
|
|
}
|
|
}
|
|
|
|
// Wait for all current goroutines to finish. Return an error that combines all errors returned
|
|
// in the group so far (if any).
|
|
func (we *waitErr) Wait() error {
|
|
we.wg.Wait()
|
|
we.mut.RLock()
|
|
defer we.mut.RUnlock()
|
|
return errors.Join(we.errs...)
|
|
}
|
|
|
|
// Unwrap returns all non-nil errors returned by our functions.
|
|
// If no errors were returned, or all errors are nil, it returns nil.
|
|
func (we *waitErr) Unwrap() []error {
|
|
errs := make([]error, 0, len(we.errs))
|
|
for _, e := range we.errs {
|
|
if e != nil {
|
|
errs = append(errs, e)
|
|
}
|
|
}
|
|
if len(errs) == 0 {
|
|
return nil
|
|
}
|
|
return errs
|
|
}
|