errgroup: add package
Package errgroup provides synchronization, error propagation, and Context cancellation for groups of goroutines working on subtasks of a common task. Change-Id: Ic9e51f6f846124076bbff9d53b0f09dc7fc5f2f0 Reviewed-on: https://go-review.googlesource.com/24894 Reviewed-by: Sameer Ajmani <sameer@golang.org> Reviewed-by: Brad Fitzpatrick <bradfitz@golang.org>
This commit is contained in:
		
					parent
					
						
							
								c6cf2573d3
							
						
					
				
			
			
				commit
				
					
						457c582840
					
				
			
		
					 3 changed files with 344 additions and 0 deletions
				
			
		
							
								
								
									
										67
									
								
								errgroup/errgroup.go
									
										
									
									
									
										Normal file
									
								
							
							
						
						
									
										67
									
								
								errgroup/errgroup.go
									
										
									
									
									
										Normal file
									
								
							|  | @ -0,0 +1,67 @@ | ||||||
|  | // Copyright 2016 The Go Authors. All rights reserved. | ||||||
|  | // Use of this source code is governed by a BSD-style | ||||||
|  | // license that can be found in the LICENSE file. | ||||||
|  | 
 | ||||||
|  | // Package errgroup provides synchronization, error propagation, and Context | ||||||
|  | // cancelation for groups of goroutines working on subtasks of a common task. | ||||||
|  | package errgroup | ||||||
|  | 
 | ||||||
|  | import ( | ||||||
|  | 	"sync" | ||||||
|  | 
 | ||||||
|  | 	"golang.org/x/net/context" | ||||||
|  | ) | ||||||
|  | 
 | ||||||
|  | // A Group is a collection of goroutines working on subtasks that are part of | ||||||
|  | // the same overall task. | ||||||
|  | // | ||||||
|  | // A zero Group is valid and does not cancel on error. | ||||||
|  | type Group struct { | ||||||
|  | 	cancel func() | ||||||
|  | 
 | ||||||
|  | 	wg sync.WaitGroup | ||||||
|  | 
 | ||||||
|  | 	errOnce sync.Once | ||||||
|  | 	err     error | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // WithContext returns a new Group and an associated Context derived from ctx. | ||||||
|  | // | ||||||
|  | // The derived Context is canceled the first time a function passed to Go | ||||||
|  | // returns a non-nil error or the first time Wait returns, whichever occurs | ||||||
|  | // first. | ||||||
|  | func WithContext(ctx context.Context) (*Group, context.Context) { | ||||||
|  | 	ctx, cancel := context.WithCancel(ctx) | ||||||
|  | 	return &Group{cancel: cancel}, ctx | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // Wait blocks until all function calls from the Go method have returned, then | ||||||
|  | // returns the first non-nil error (if any) from them. | ||||||
|  | func (g *Group) Wait() error { | ||||||
|  | 	g.wg.Wait() | ||||||
|  | 	if g.cancel != nil { | ||||||
|  | 		g.cancel() | ||||||
|  | 	} | ||||||
|  | 	return g.err | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // Go calls the given function in a new goroutine. | ||||||
|  | // | ||||||
|  | // The first call to return a non-nil error cancels the group; its error will be | ||||||
|  | // returned by Wait. | ||||||
|  | func (g *Group) Go(f func() error) { | ||||||
|  | 	g.wg.Add(1) | ||||||
|  | 
 | ||||||
|  | 	go func() { | ||||||
|  | 		defer g.wg.Done() | ||||||
|  | 
 | ||||||
|  | 		if err := f(); err != nil { | ||||||
|  | 			g.errOnce.Do(func() { | ||||||
|  | 				g.err = err | ||||||
|  | 				if g.cancel != nil { | ||||||
|  | 					g.cancel() | ||||||
|  | 				} | ||||||
|  | 			}) | ||||||
|  | 		} | ||||||
|  | 	}() | ||||||
|  | } | ||||||
							
								
								
									
										101
									
								
								errgroup/errgroup_example_md5all_test.go
									
										
									
									
									
										Normal file
									
								
							
							
						
						
									
										101
									
								
								errgroup/errgroup_example_md5all_test.go
									
										
									
									
									
										Normal file
									
								
							|  | @ -0,0 +1,101 @@ | ||||||
|  | // Copyright 2016 The Go Authors. All rights reserved. | ||||||
|  | // Use of this source code is governed by a BSD-style | ||||||
|  | // license that can be found in the LICENSE file. | ||||||
|  | 
 | ||||||
|  | package errgroup_test | ||||||
|  | 
 | ||||||
|  | import ( | ||||||
|  | 	"crypto/md5" | ||||||
|  | 	"fmt" | ||||||
|  | 	"io/ioutil" | ||||||
|  | 	"log" | ||||||
|  | 	"os" | ||||||
|  | 	"path/filepath" | ||||||
|  | 
 | ||||||
|  | 	"golang.org/x/net/context" | ||||||
|  | 	"golang.org/x/sync/errgroup" | ||||||
|  | ) | ||||||
|  | 
 | ||||||
|  | // Pipeline demonstrates the use of a Group to implement a multi-stage | ||||||
|  | // pipeline: a version of the MD5All function with bounded parallelism from | ||||||
|  | // https://blog.golang.org/pipelines. | ||||||
|  | func ExampleGroup_pipeline() { | ||||||
|  | 	m, err := MD5All(context.Background(), ".") | ||||||
|  | 	if err != nil { | ||||||
|  | 		log.Fatal(err) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	for k, sum := range m { | ||||||
|  | 		fmt.Printf("%s:\t%x\n", k, sum) | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | type result struct { | ||||||
|  | 	path string | ||||||
|  | 	sum  [md5.Size]byte | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // MD5All reads all the files in the file tree rooted at root and returns a map | ||||||
|  | // from file path to the MD5 sum of the file's contents. If the directory walk | ||||||
|  | // fails or any read operation fails, MD5All returns an error. | ||||||
|  | func MD5All(ctx context.Context, root string) (map[string][md5.Size]byte, error) { | ||||||
|  | 	// ctx is canceled when MD5All calls g.Wait(). When this version of MD5All | ||||||
|  | 	// returns - even in case of error! - we know that all of the goroutines have | ||||||
|  | 	// finished and the memory they were using can be garbage-collected. | ||||||
|  | 	g, ctx := errgroup.WithContext(ctx) | ||||||
|  | 	paths := make(chan string) | ||||||
|  | 
 | ||||||
|  | 	g.Go(func() error { | ||||||
|  | 		defer close(paths) | ||||||
|  | 		return filepath.Walk(root, func(path string, info os.FileInfo, err error) error { | ||||||
|  | 			if err != nil { | ||||||
|  | 				return err | ||||||
|  | 			} | ||||||
|  | 			if !info.Mode().IsRegular() { | ||||||
|  | 				return nil | ||||||
|  | 			} | ||||||
|  | 			select { | ||||||
|  | 			case paths <- path: | ||||||
|  | 			case <-ctx.Done(): | ||||||
|  | 				return ctx.Err() | ||||||
|  | 			} | ||||||
|  | 			return nil | ||||||
|  | 		}) | ||||||
|  | 	}) | ||||||
|  | 
 | ||||||
|  | 	// Start a fixed number of goroutines to read and digest files. | ||||||
|  | 	c := make(chan result) | ||||||
|  | 	const numDigesters = 20 | ||||||
|  | 	for i := 0; i < numDigesters; i++ { | ||||||
|  | 		g.Go(func() error { | ||||||
|  | 			for path := range paths { | ||||||
|  | 				data, err := ioutil.ReadFile(path) | ||||||
|  | 				if err != nil { | ||||||
|  | 					return err | ||||||
|  | 				} | ||||||
|  | 				select { | ||||||
|  | 				case c <- result{path, md5.Sum(data)}: | ||||||
|  | 				case <-ctx.Done(): | ||||||
|  | 					return ctx.Err() | ||||||
|  | 				} | ||||||
|  | 			} | ||||||
|  | 			return nil | ||||||
|  | 		}) | ||||||
|  | 	} | ||||||
|  | 	go func() { | ||||||
|  | 		g.Wait() | ||||||
|  | 		close(c) | ||||||
|  | 	}() | ||||||
|  | 
 | ||||||
|  | 	m := make(map[string][md5.Size]byte) | ||||||
|  | 	for r := range c { | ||||||
|  | 		m[r.path] = r.sum | ||||||
|  | 	} | ||||||
|  | 	// Check whether any of the goroutines failed. Since g is accumulating the | ||||||
|  | 	// errors, we don't need to send them (or check for them) in the individual | ||||||
|  | 	// results sent on the channel. | ||||||
|  | 	if err := g.Wait(); err != nil { | ||||||
|  | 		return nil, err | ||||||
|  | 	} | ||||||
|  | 	return m, nil | ||||||
|  | } | ||||||
							
								
								
									
										176
									
								
								errgroup/errgroup_test.go
									
										
									
									
									
										Normal file
									
								
							
							
						
						
									
										176
									
								
								errgroup/errgroup_test.go
									
										
									
									
									
										Normal file
									
								
							|  | @ -0,0 +1,176 @@ | ||||||
|  | // Copyright 2016 The Go Authors. All rights reserved. | ||||||
|  | // Use of this source code is governed by a BSD-style | ||||||
|  | // license that can be found in the LICENSE file. | ||||||
|  | 
 | ||||||
|  | package errgroup_test | ||||||
|  | 
 | ||||||
|  | import ( | ||||||
|  | 	"errors" | ||||||
|  | 	"fmt" | ||||||
|  | 	"net/http" | ||||||
|  | 	"os" | ||||||
|  | 	"testing" | ||||||
|  | 
 | ||||||
|  | 	"golang.org/x/net/context" | ||||||
|  | 	"golang.org/x/sync/errgroup" | ||||||
|  | ) | ||||||
|  | 
 | ||||||
|  | var ( | ||||||
|  | 	Web   = fakeSearch("web") | ||||||
|  | 	Image = fakeSearch("image") | ||||||
|  | 	Video = fakeSearch("video") | ||||||
|  | ) | ||||||
|  | 
 | ||||||
|  | type Result string | ||||||
|  | type Search func(ctx context.Context, query string) (Result, error) | ||||||
|  | 
 | ||||||
|  | func fakeSearch(kind string) Search { | ||||||
|  | 	return func(_ context.Context, query string) (Result, error) { | ||||||
|  | 		return Result(fmt.Sprintf("%s result for %q", kind, query)), nil | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // JustErrors illustrates the use of a Group in place of a sync.WaitGroup to | ||||||
|  | // simplify goroutine counting and error handling. This example is derived from | ||||||
|  | // the sync.WaitGroup example at https://golang.org/pkg/sync/#example_WaitGroup. | ||||||
|  | func ExampleGroup_justErrors() { | ||||||
|  | 	var g errgroup.Group | ||||||
|  | 	var urls = []string{ | ||||||
|  | 		"http://www.golang.org/", | ||||||
|  | 		"http://www.google.com/", | ||||||
|  | 		"http://www.somestupidname.com/", | ||||||
|  | 	} | ||||||
|  | 	for _, url := range urls { | ||||||
|  | 		// Launch a goroutine to fetch the URL. | ||||||
|  | 		url := url // https://golang.org/doc/faq#closures_and_goroutines | ||||||
|  | 		g.Go(func(url string) error { | ||||||
|  | 			// Fetch the URL. | ||||||
|  | 			resp, err := http.Get(url) | ||||||
|  | 			if err == nil { | ||||||
|  | 				resp.Body.Close() | ||||||
|  | 			} | ||||||
|  | 			return err | ||||||
|  | 		}) | ||||||
|  | 	} | ||||||
|  | 	// Wait for all HTTP fetches to complete. | ||||||
|  | 	if err := wg.Wait(); err == nil { | ||||||
|  | 		fmt.Println("Successfully fetched all URLs.") | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // Parallel illustrates the use of a Group for synchronizing a simple parallel | ||||||
|  | // task: the "Google Search 2.0" function from | ||||||
|  | // https://talks.golang.org/2012/concurrency.slide#46, augmented with a Context | ||||||
|  | // and error-handling. | ||||||
|  | func ExampleGroup_parallel() { | ||||||
|  | 	Google := func(ctx context.Context, query string) ([]Result, error) { | ||||||
|  | 		g, ctx := errgroup.WithContext(ctx) | ||||||
|  | 
 | ||||||
|  | 		searches := []Search{Web, Image, Video} | ||||||
|  | 		results := make([]Result, len(searches)) | ||||||
|  | 		for i, search := range searches { | ||||||
|  | 			i, search := i, search // https://golang.org/doc/faq#closures_and_goroutines | ||||||
|  | 			g.Go(func() error { | ||||||
|  | 				result, err := search(ctx, query) | ||||||
|  | 				if err == nil { | ||||||
|  | 					results[i] = result | ||||||
|  | 				} | ||||||
|  | 				return err | ||||||
|  | 			}) | ||||||
|  | 		} | ||||||
|  | 		if err := g.Wait(); err != nil { | ||||||
|  | 			return nil, err | ||||||
|  | 		} | ||||||
|  | 		return results, nil | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	results, err := Google(context.Background(), "golang") | ||||||
|  | 	if err != nil { | ||||||
|  | 		fmt.Fprintln(os.Stderr, err) | ||||||
|  | 		return | ||||||
|  | 	} | ||||||
|  | 	for _, result := range results { | ||||||
|  | 		fmt.Println(result) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	// Output: | ||||||
|  | 	// web result for "golang" | ||||||
|  | 	// image result for "golang" | ||||||
|  | 	// video result for "golang" | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func TestZeroGroup(t *testing.T) { | ||||||
|  | 	err1 := errors.New("errgroup_test: 1") | ||||||
|  | 	err2 := errors.New("errgroup_test: 2") | ||||||
|  | 
 | ||||||
|  | 	cases := []struct { | ||||||
|  | 		errs []error | ||||||
|  | 	}{ | ||||||
|  | 		{errs: []error{}}, | ||||||
|  | 		{errs: []error{nil}}, | ||||||
|  | 		{errs: []error{err1}}, | ||||||
|  | 		{errs: []error{err1, nil}}, | ||||||
|  | 		{errs: []error{err1, nil, err2}}, | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	for _, tc := range cases { | ||||||
|  | 		var g errgroup.Group | ||||||
|  | 
 | ||||||
|  | 		var firstErr error | ||||||
|  | 		for i, err := range tc.errs { | ||||||
|  | 			err := err | ||||||
|  | 			g.Go(func() error { return err }) | ||||||
|  | 
 | ||||||
|  | 			if firstErr == nil && err != nil { | ||||||
|  | 				firstErr = err | ||||||
|  | 			} | ||||||
|  | 
 | ||||||
|  | 			if gErr := g.Wait(); gErr != firstErr { | ||||||
|  | 				t.Errorf("after %T.Go(func() error { return err }) for err in %v\n"+ | ||||||
|  | 					"g.Wait() = %v; want %v", | ||||||
|  | 					g, tc.errs[:i+1], err, firstErr) | ||||||
|  | 			} | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func TestWithContext(t *testing.T) { | ||||||
|  | 	errDoom := errors.New("group_test: doomed") | ||||||
|  | 
 | ||||||
|  | 	cases := []struct { | ||||||
|  | 		errs []error | ||||||
|  | 		want error | ||||||
|  | 	}{ | ||||||
|  | 		{want: nil}, | ||||||
|  | 		{errs: []error{nil}, want: nil}, | ||||||
|  | 		{errs: []error{errDoom}, want: errDoom}, | ||||||
|  | 		{errs: []error{errDoom, nil}, want: errDoom}, | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	for _, tc := range cases { | ||||||
|  | 		g, ctx := errgroup.WithContext(context.Background()) | ||||||
|  | 
 | ||||||
|  | 		for _, err := range tc.errs { | ||||||
|  | 			err := err | ||||||
|  | 			g.Go(func() error { return err }) | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
|  | 		if err := g.Wait(); err != tc.want { | ||||||
|  | 			t.Errorf("after %T.Go(func() error { return err }) for err in %v\n"+ | ||||||
|  | 				"g.Wait() = %v; want %v", | ||||||
|  | 				g, tc.errs, err, tc.want) | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
|  | 		canceled := false | ||||||
|  | 		select { | ||||||
|  | 		case <-ctx.Done(): | ||||||
|  | 			canceled = true | ||||||
|  | 		default: | ||||||
|  | 		} | ||||||
|  | 		if !canceled { | ||||||
|  | 			t.Errorf("after %T.Go(func() error { return err }) for err in %v\n"+ | ||||||
|  | 				"ctx.Done() was not closed", | ||||||
|  | 				g, tc.errs) | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | } | ||||||
		Loading…
	
	Add table
		Add a link
		
	
		Reference in a new issue