// Package waiterr provides WaitErr, a sync.WaitGroup wrapper that collects
// the errors returned by the goroutines it runs.
package waiterr

import (
	"errors"
	"sync"
)

// WaitErr wraps a sync.WaitGroup with error handling.
//
// The zero value is ready to use. A WaitErr must not be copied after first
// use, since it contains sync primitives.
type WaitErr struct {
	wg sync.WaitGroup

	// mut guards errs and firstErr.
	mut      sync.RWMutex
	errs     []error // result of every finished function (nil results included), in the order recorded
	firstErr error   // first non-nil result; written at most once

	// firstErrOnce makes sure firstErr is recorded and firstErrCh is closed
	// exactly once.
	firstErrOnce sync.Once

	// firstErrCh is closed right after firstErr is set. Closing (rather than
	// sending a value) wakes every concurrent WaitForError caller, not just
	// one. It is created lazily by initFirstErrCh so the zero value is usable.
	firstErrCh         chan struct{}
	initFirstErrChOnce sync.Once
}

// initFirstErrCh creates firstErrCh exactly once. Both Go and WaitForError go
// through the same sync.Once, which also orders the channel's creation before
// every use of it (no unsynchronized nil check on the field).
func (we *WaitErr) initFirstErrCh() {
	we.initFirstErrChOnce.Do(func() { we.firstErrCh = make(chan struct{}) })
}

// Go runs f in its own goroutine. When f returns, its error is stored and
// reported by Wait. The first non-nil error is also made available to
// WaitForError.
//
// As with sync.WaitGroup, a call to Go made while no function is running
// must happen before the Wait or WaitForError call that should observe it.
func (we *WaitErr) Go(f func() error) {
	we.initFirstErrCh()
	we.wg.Go(func() {
		err := f()
		if err != nil {
			we.firstErrOnce.Do(func() {
				we.mut.Lock()
				we.firstErr = err
				we.mut.Unlock()
				// Signal only after firstErr is written, so every woken
				// waiter observes it.
				close(we.firstErrCh)
			})
		}
		we.mut.Lock()
		defer we.mut.Unlock()
		we.errs = append(we.errs, err)
	})
}

// WaitForError waits for the first error returned by one of the functions
// started with Go and returns it immediately, without waiting for the
// remaining functions. If every function returns nil, WaitForError returns
// nil once they have all finished. If Go has not been called at all, it
// returns nil right away.
//
// It is safe to call WaitForError repeatedly and from several goroutines.
// Every call reports the same first error.
//
// Each call that does not find an error already recorded starts a helper
// goroutine that waits for the whole group. When WaitForError returns early
// with an error, that helper keeps running until every function has returned.
func (we *WaitErr) WaitForError() error {
	we.initFirstErrCh()

	// Fast path: an error has already been recorded.
	we.mut.RLock()
	err := we.firstErr
	we.mut.RUnlock()
	if err != nil {
		return err
	}

	done := make(chan struct{})
	go func() {
		we.wg.Wait()
		close(done)
	}()

	select {
	case <-we.firstErrCh:
	case <-done:
	}

	// firstErr is written before firstErrCh is closed, and before the failing
	// function's goroutine is marked done. It is therefore up to date on
	// either path, and nil only if no function failed.
	we.mut.RLock()
	defer we.mut.RUnlock()
	return we.firstErr
}

// Wait blocks until every function started with Go has returned. It then
// returns all errors they returned, combined with errors.Join (nil if none
// failed).
func (we *WaitErr) Wait() error {
	we.wg.Wait()
	we.mut.RLock()
	defer we.mut.RUnlock()
	return errors.Join(we.errs...)
}

// Unwrap returns the non-nil errors returned by the functions started with
// Go, in the order they were recorded, or nil if there are none. It does not
// wait for running functions; call Wait first for a complete result.
func (we *WaitErr) Unwrap() []error {
	// Lock: Go appends to errs concurrently.
	we.mut.RLock()
	defer we.mut.RUnlock()
	var errs []error
	for _, e := range we.errs {
		if e != nil {
			errs = append(errs, e)
		}
	}
	return errs
}