blob: 2624cf6c0b0320f662a60f5a7af00bf640c0e921 [file] [log] [blame]
package util
import (
	"bytes"
	"errors"
	"fmt"
	"sort"
	"sync"
)
// NamedErrGroup runs a set of named functions concurrently, in the spirit of
// errgroup.Group. Unlike errgroup.Group it does not stop at the first failure:
// it waits for every function and reports each error together with the name it
// was registered under. Create one with NewNamedErrGroup.
type NamedErrGroup struct {
	// mtx guards errs.
	mtx sync.Mutex
	// errs holds the error returned by each failed function, keyed by the
	// name passed to Go. Functions that return nil have no entry.
	errs map[string]error

	// wg counts the goroutines started by Go that have not finished yet.
	wg sync.WaitGroup
}
// NewNamedErrGroup allocates a NamedErrGroup with an empty error map and
// fresh synchronization primitives, ready for calls to Go and Wait.
func NewNamedErrGroup() *NamedErrGroup {
	group := new(NamedErrGroup)
	group.errs = make(map[string]error)
	group.mtx = sync.Mutex{}
	group.wg = sync.WaitGroup{}
	return group
}
// Go starts fn in a new goroutine and registers it with the group under name.
//
// If fn returns a non-nil error, that error is stored under name. When several
// functions share a name, whichever error is stored last replaces the others.
// A nil return leaves no trace. The error map is created on first use, so Go
// is also safe to call on a zero-value NamedErrGroup.
func (g *NamedErrGroup) Go(name string, fn func() error) {
	g.wg.Add(1)
	go func() {
		defer g.wg.Done()

		err := fn()
		if err == nil {
			return
		}

		g.mtx.Lock()
		defer g.mtx.Unlock()
		if g.errs == nil {
			// Writing to a nil map panics; create it lazily so a group that
			// was not built with the constructor still works.
			g.errs = make(map[string]error)
		}
		g.errs[name] = err
	}()
}
// Wait blocks until every goroutine started by Go has finished and then
// reports the collected failures.
//
// It returns nil when no function returned an error. Otherwise it returns a
// single error whose message starts with a header line followed by one
// "\t<name>: <error>" line per failed function, ordered by name so that the
// message is the same on every run.
func (g *NamedErrGroup) Wait() error {
	g.wg.Wait()
	if len(g.errs) == 0 {
		return nil
	}

	// Map iteration order is randomized; sort the names so the resulting
	// message is deterministic.
	names := make([]string, 0, len(g.errs))
	for name := range g.errs {
		names = append(names, name)
	}
	sort.Strings(names)

	var msg bytes.Buffer
	msg.WriteString("NamedErrGroup encountered errors:\n")
	for _, name := range names {
		fmt.Fprintf(&msg, "\t%s: %s\n", name, g.errs[name])
	}
	return errors.New(msg.String())
}