88 lines
		
	
	
	
		
			1.8 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
		
		
			
		
	
	
			88 lines
		
	
	
	
		
			1.8 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
| 
								 | 
							
								// Copyright 2013 The Go Authors. All rights reserved.
							 | 
						||
| 
								 | 
							
								// Use of this source code is governed by a BSD-style
							 | 
						||
| 
								 | 
							
								// license that can be found in the LICENSE file.
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								package singleflight
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								import (
							 | 
						||
| 
								 | 
							
									"errors"
							 | 
						||
| 
								 | 
							
									"fmt"
							 | 
						||
| 
								 | 
							
									"sync"
							 | 
						||
| 
								 | 
							
									"sync/atomic"
							 | 
						||
| 
								 | 
							
									"testing"
							 | 
						||
| 
								 | 
							
									"time"
							 | 
						||
| 
								 | 
							
								)
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								func TestDo(t *testing.T) {
							 | 
						||
| 
								 | 
							
									var g Group
							 | 
						||
| 
								 | 
							
									v, err, _ := g.Do("key", func() (interface{}, error) {
							 | 
						||
| 
								 | 
							
										return "bar", nil
							 | 
						||
| 
								 | 
							
									})
							 | 
						||
| 
								 | 
							
									if got, want := fmt.Sprintf("%v (%T)", v, v), "bar (string)"; got != want {
							 | 
						||
| 
								 | 
							
										t.Errorf("Do = %v; want %v", got, want)
							 | 
						||
| 
								 | 
							
									}
							 | 
						||
| 
								 | 
							
									if err != nil {
							 | 
						||
| 
								 | 
							
										t.Errorf("Do error = %v", err)
							 | 
						||
| 
								 | 
							
									}
							 | 
						||
| 
								 | 
							
								}
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								func TestDoErr(t *testing.T) {
							 | 
						||
| 
								 | 
							
									var g Group
							 | 
						||
| 
								 | 
							
									someErr := errors.New("Some error")
							 | 
						||
| 
								 | 
							
									v, err, _ := g.Do("key", func() (interface{}, error) {
							 | 
						||
| 
								 | 
							
										return nil, someErr
							 | 
						||
| 
								 | 
							
									})
							 | 
						||
| 
								 | 
							
									if err != someErr {
							 | 
						||
| 
								 | 
							
										t.Errorf("Do error = %v; want someErr %v", err, someErr)
							 | 
						||
| 
								 | 
							
									}
							 | 
						||
| 
								 | 
							
									if v != nil {
							 | 
						||
| 
								 | 
							
										t.Errorf("unexpected non-nil value %#v", v)
							 | 
						||
| 
								 | 
							
									}
							 | 
						||
| 
								 | 
							
								}
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								func TestDoDupSuppress(t *testing.T) {
							 | 
						||
| 
								 | 
							
									var g Group
							 | 
						||
| 
								 | 
							
									var wg1, wg2 sync.WaitGroup
							 | 
						||
| 
								 | 
							
									c := make(chan string, 1)
							 | 
						||
| 
								 | 
							
									var calls int32
							 | 
						||
| 
								 | 
							
									fn := func() (interface{}, error) {
							 | 
						||
| 
								 | 
							
										if atomic.AddInt32(&calls, 1) == 1 {
							 | 
						||
| 
								 | 
							
											// First invocation.
							 | 
						||
| 
								 | 
							
											wg1.Done()
							 | 
						||
| 
								 | 
							
										}
							 | 
						||
| 
								 | 
							
										v := <-c
							 | 
						||
| 
								 | 
							
										c <- v // pump; make available for any future calls
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
										time.Sleep(10 * time.Millisecond) // let more goroutines enter Do
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
										return v, nil
							 | 
						||
| 
								 | 
							
									}
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
									const n = 10
							 | 
						||
| 
								 | 
							
									wg1.Add(1)
							 | 
						||
| 
								 | 
							
									for i := 0; i < n; i++ {
							 | 
						||
| 
								 | 
							
										wg1.Add(1)
							 | 
						||
| 
								 | 
							
										wg2.Add(1)
							 | 
						||
| 
								 | 
							
										go func() {
							 | 
						||
| 
								 | 
							
											defer wg2.Done()
							 | 
						||
| 
								 | 
							
											wg1.Done()
							 | 
						||
| 
								 | 
							
											v, err, _ := g.Do("key", fn)
							 | 
						||
| 
								 | 
							
											if err != nil {
							 | 
						||
| 
								 | 
							
												t.Errorf("Do error: %v", err)
							 | 
						||
| 
								 | 
							
												return
							 | 
						||
| 
								 | 
							
											}
							 | 
						||
| 
								 | 
							
											if s, _ := v.(string); s != "bar" {
							 | 
						||
| 
								 | 
							
												t.Errorf("Do = %T %v; want %q", v, v, "bar")
							 | 
						||
| 
								 | 
							
											}
							 | 
						||
| 
								 | 
							
										}()
							 | 
						||
| 
								 | 
							
									}
							 | 
						||
| 
								 | 
							
									wg1.Wait()
							 | 
						||
| 
								 | 
							
									// At least one goroutine is in fn now and all of them have at
							 | 
						||
| 
								 | 
							
									// least reached the line before the Do.
							 | 
						||
| 
								 | 
							
									c <- "bar"
							 | 
						||
| 
								 | 
							
									wg2.Wait()
							 | 
						||
| 
								 | 
							
									if got := atomic.LoadInt32(&calls); got <= 0 || got >= n {
							 | 
						||
| 
								 | 
							
										t.Errorf("number of calls = %d; want over 0 and less than %d", got, n)
							 | 
						||
| 
								 | 
							
									}
							 | 
						||
| 
								 | 
							
								}
							 |