...

Source file src/context/x_test.go

Documentation: context

     1  // Copyright 2016 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package context_test
     6  
     7  import (
     8  	. "context"
     9  	"errors"
    10  	"fmt"
    11  	"math/rand"
    12  	"runtime"
    13  	"strings"
    14  	"sync"
    15  	"testing"
    16  	"time"
    17  )
    18  
    19  // Each XTestFoo in context_test.go must be called from a TestFoo here to run.
    20  func TestParentFinishesChild(t *testing.T) {
    21  	XTestParentFinishesChild(t) // uses unexported context types
    22  }
    23  func TestChildFinishesFirst(t *testing.T) {
    24  	XTestChildFinishesFirst(t) // uses unexported context types
    25  }
    26  func TestCancelRemoves(t *testing.T) {
    27  	XTestCancelRemoves(t) // uses unexported context types
    28  }
    29  func TestCustomContextGoroutines(t *testing.T) {
    30  	XTestCustomContextGoroutines(t) // reads the context.goroutines counter
    31  }
    32  
    33  // The following are regular tests in package context_test.
    34  
    35  // otherContext is a Context that's not one of the types defined in context.go.
    36  // This lets us test code paths that differ based on the underlying type of the
    37  // Context.
    38  type otherContext struct {
    39  	Context
    40  }
    41  
    42  const (
    43  	shortDuration    = 1 * time.Millisecond // a reasonable duration to block in a test
    44  	veryLongDuration = 1000 * time.Hour     // an arbitrary upper bound on the test's running time
    45  )
    46  
    47  // quiescent returns an arbitrary duration by which the program should have
    48  // completed any remaining work and reached a steady (idle) state.
    49  func quiescent(t *testing.T) time.Duration {
    50  	deadline, ok := t.Deadline()
    51  	if !ok {
    52  		return 5 * time.Second
    53  	}
    54  
    55  	const arbitraryCleanupMargin = 1 * time.Second
    56  	return time.Until(deadline) - arbitraryCleanupMargin
    57  }
    58  func TestBackground(t *testing.T) {
    59  	c := Background()
    60  	if c == nil {
    61  		t.Fatalf("Background returned nil")
    62  	}
    63  	select {
    64  	case x := <-c.Done():
    65  		t.Errorf("<-c.Done() == %v want nothing (it should block)", x)
    66  	default:
    67  	}
    68  	if got, want := fmt.Sprint(c), "context.Background"; got != want {
    69  		t.Errorf("Background().String() = %q want %q", got, want)
    70  	}
    71  }
    72  
    73  func TestTODO(t *testing.T) {
    74  	c := TODO()
    75  	if c == nil {
    76  		t.Fatalf("TODO returned nil")
    77  	}
    78  	select {
    79  	case x := <-c.Done():
    80  		t.Errorf("<-c.Done() == %v want nothing (it should block)", x)
    81  	default:
    82  	}
    83  	if got, want := fmt.Sprint(c), "context.TODO"; got != want {
    84  		t.Errorf("TODO().String() = %q want %q", got, want)
    85  	}
    86  }
    87  
    88  func TestWithCancel(t *testing.T) {
    89  	c1, cancel := WithCancel(Background())
    90  
    91  	if got, want := fmt.Sprint(c1), "context.Background.WithCancel"; got != want {
    92  		t.Errorf("c1.String() = %q want %q", got, want)
    93  	}
    94  
    95  	o := otherContext{c1}
    96  	c2, _ := WithCancel(o)
    97  	contexts := []Context{c1, o, c2}
    98  
    99  	for i, c := range contexts {
   100  		if d := c.Done(); d == nil {
   101  			t.Errorf("c[%d].Done() == %v want non-nil", i, d)
   102  		}
   103  		if e := c.Err(); e != nil {
   104  			t.Errorf("c[%d].Err() == %v want nil", i, e)
   105  		}
   106  
   107  		select {
   108  		case x := <-c.Done():
   109  			t.Errorf("<-c.Done() == %v want nothing (it should block)", x)
   110  		default:
   111  		}
   112  	}
   113  
   114  	cancel() // Should propagate synchronously.
   115  	for i, c := range contexts {
   116  		select {
   117  		case <-c.Done():
   118  		default:
   119  			t.Errorf("<-c[%d].Done() blocked, but shouldn't have", i)
   120  		}
   121  		if e := c.Err(); e != Canceled {
   122  			t.Errorf("c[%d].Err() == %v want %v", i, e, Canceled)
   123  		}
   124  	}
   125  }
   126  
   127  func testDeadline(c Context, name string, t *testing.T) {
   128  	t.Helper()
   129  	d := quiescent(t)
   130  	timer := time.NewTimer(d)
   131  	defer timer.Stop()
   132  	select {
   133  	case <-timer.C:
   134  		t.Fatalf("%s: context not timed out after %v", name, d)
   135  	case <-c.Done():
   136  	}
   137  	if e := c.Err(); e != DeadlineExceeded {
   138  		t.Errorf("%s: c.Err() == %v; want %v", name, e, DeadlineExceeded)
   139  	}
   140  }
   141  
   142  func TestDeadline(t *testing.T) {
   143  	t.Parallel()
   144  
   145  	c, _ := WithDeadline(Background(), time.Now().Add(shortDuration))
   146  	if got, prefix := fmt.Sprint(c), "context.Background.WithDeadline("; !strings.HasPrefix(got, prefix) {
   147  		t.Errorf("c.String() = %q want prefix %q", got, prefix)
   148  	}
   149  	testDeadline(c, "WithDeadline", t)
   150  
   151  	c, _ = WithDeadline(Background(), time.Now().Add(shortDuration))
   152  	o := otherContext{c}
   153  	testDeadline(o, "WithDeadline+otherContext", t)
   154  
   155  	c, _ = WithDeadline(Background(), time.Now().Add(shortDuration))
   156  	o = otherContext{c}
   157  	c, _ = WithDeadline(o, time.Now().Add(veryLongDuration))
   158  	testDeadline(c, "WithDeadline+otherContext+WithDeadline", t)
   159  
   160  	c, _ = WithDeadline(Background(), time.Now().Add(-shortDuration))
   161  	testDeadline(c, "WithDeadline+inthepast", t)
   162  
   163  	c, _ = WithDeadline(Background(), time.Now())
   164  	testDeadline(c, "WithDeadline+now", t)
   165  }
   166  
   167  func TestTimeout(t *testing.T) {
   168  	t.Parallel()
   169  
   170  	c, _ := WithTimeout(Background(), shortDuration)
   171  	if got, prefix := fmt.Sprint(c), "context.Background.WithDeadline("; !strings.HasPrefix(got, prefix) {
   172  		t.Errorf("c.String() = %q want prefix %q", got, prefix)
   173  	}
   174  	testDeadline(c, "WithTimeout", t)
   175  
   176  	c, _ = WithTimeout(Background(), shortDuration)
   177  	o := otherContext{c}
   178  	testDeadline(o, "WithTimeout+otherContext", t)
   179  
   180  	c, _ = WithTimeout(Background(), shortDuration)
   181  	o = otherContext{c}
   182  	c, _ = WithTimeout(o, veryLongDuration)
   183  	testDeadline(c, "WithTimeout+otherContext+WithTimeout", t)
   184  }
   185  
   186  func TestCanceledTimeout(t *testing.T) {
   187  	c, _ := WithTimeout(Background(), time.Second)
   188  	o := otherContext{c}
   189  	c, cancel := WithTimeout(o, veryLongDuration)
   190  	cancel() // Should propagate synchronously.
   191  	select {
   192  	case <-c.Done():
   193  	default:
   194  		t.Errorf("<-c.Done() blocked, but shouldn't have")
   195  	}
   196  	if e := c.Err(); e != Canceled {
   197  		t.Errorf("c.Err() == %v want %v", e, Canceled)
   198  	}
   199  }
   200  
   201  type key1 int
   202  type key2 int
   203  
   204  func (k key2) String() string { return fmt.Sprintf("%[1]T(%[1]d)", k) }
   205  
   206  var k1 = key1(1)
   207  var k2 = key2(1) // same int as k1, different type
   208  var k3 = key2(3) // same type as k2, different int
   209  
   210  func TestValues(t *testing.T) {
   211  	check := func(c Context, nm, v1, v2, v3 string) {
   212  		if v, ok := c.Value(k1).(string); ok == (len(v1) == 0) || v != v1 {
   213  			t.Errorf(`%s.Value(k1).(string) = %q, %t want %q, %t`, nm, v, ok, v1, len(v1) != 0)
   214  		}
   215  		if v, ok := c.Value(k2).(string); ok == (len(v2) == 0) || v != v2 {
   216  			t.Errorf(`%s.Value(k2).(string) = %q, %t want %q, %t`, nm, v, ok, v2, len(v2) != 0)
   217  		}
   218  		if v, ok := c.Value(k3).(string); ok == (len(v3) == 0) || v != v3 {
   219  			t.Errorf(`%s.Value(k3).(string) = %q, %t want %q, %t`, nm, v, ok, v3, len(v3) != 0)
   220  		}
   221  	}
   222  
   223  	c0 := Background()
   224  	check(c0, "c0", "", "", "")
   225  
   226  	c1 := WithValue(Background(), k1, "c1k1")
   227  	check(c1, "c1", "c1k1", "", "")
   228  
   229  	if got, want := fmt.Sprint(c1), `context.Background.WithValue(context_test.key1, c1k1)`; got != want {
   230  		t.Errorf("c.String() = %q want %q", got, want)
   231  	}
   232  
   233  	c2 := WithValue(c1, k2, "c2k2")
   234  	check(c2, "c2", "c1k1", "c2k2", "")
   235  
   236  	if got, want := fmt.Sprint(c2), `context.Background.WithValue(context_test.key1, c1k1).WithValue(context_test.key2(1), c2k2)`; got != want {
   237  		t.Errorf("c.String() = %q want %q", got, want)
   238  	}
   239  
   240  	c3 := WithValue(c2, k3, "c3k3")
   241  	check(c3, "c2", "c1k1", "c2k2", "c3k3")
   242  
   243  	c4 := WithValue(c3, k1, nil)
   244  	check(c4, "c4", "", "c2k2", "c3k3")
   245  
   246  	if got, want := fmt.Sprint(c4), `context.Background.WithValue(context_test.key1, c1k1).WithValue(context_test.key2(1), c2k2).WithValue(context_test.key2(3), c3k3).WithValue(context_test.key1, <nil>)`; got != want {
   247  		t.Errorf("c.String() = %q want %q", got, want)
   248  	}
   249  
   250  	o0 := otherContext{Background()}
   251  	check(o0, "o0", "", "", "")
   252  
   253  	o1 := otherContext{WithValue(Background(), k1, "c1k1")}
   254  	check(o1, "o1", "c1k1", "", "")
   255  
   256  	o2 := WithValue(o1, k2, "o2k2")
   257  	check(o2, "o2", "c1k1", "o2k2", "")
   258  
   259  	o3 := otherContext{c4}
   260  	check(o3, "o3", "", "c2k2", "c3k3")
   261  
   262  	o4 := WithValue(o3, k3, nil)
   263  	check(o4, "o4", "", "c2k2", "")
   264  }
   265  
   266  func TestAllocs(t *testing.T) {
   267  	bg := Background()
   268  	for _, test := range []struct {
   269  		desc       string
   270  		f          func()
   271  		limit      float64
   272  		gccgoLimit float64
   273  	}{
   274  		{
   275  			desc:       "Background()",
   276  			f:          func() { Background() },
   277  			limit:      0,
   278  			gccgoLimit: 0,
   279  		},
   280  		{
   281  			desc: fmt.Sprintf("WithValue(bg, %v, nil)", k1),
   282  			f: func() {
   283  				c := WithValue(bg, k1, nil)
   284  				c.Value(k1)
   285  			},
   286  			limit:      3,
   287  			gccgoLimit: 3,
   288  		},
   289  		{
   290  			desc: "WithTimeout(bg, 1*time.Nanosecond)",
   291  			f: func() {
   292  				c, _ := WithTimeout(bg, 1*time.Nanosecond)
   293  				<-c.Done()
   294  			},
   295  			limit:      12,
   296  			gccgoLimit: 15,
   297  		},
   298  		{
   299  			desc: "WithCancel(bg)",
   300  			f: func() {
   301  				c, cancel := WithCancel(bg)
   302  				cancel()
   303  				<-c.Done()
   304  			},
   305  			limit:      5,
   306  			gccgoLimit: 8,
   307  		},
   308  		{
   309  			desc: "WithTimeout(bg, 5*time.Millisecond)",
   310  			f: func() {
   311  				c, cancel := WithTimeout(bg, 5*time.Millisecond)
   312  				cancel()
   313  				<-c.Done()
   314  			},
   315  			limit:      8,
   316  			gccgoLimit: 25,
   317  		},
   318  	} {
   319  		limit := test.limit
   320  		if runtime.Compiler == "gccgo" {
   321  			// gccgo does not yet do escape analysis.
   322  			// TODO(iant): Remove this when gccgo does do escape analysis.
   323  			limit = test.gccgoLimit
   324  		}
   325  		numRuns := 100
   326  		if testing.Short() {
   327  			numRuns = 10
   328  		}
   329  		if n := testing.AllocsPerRun(numRuns, test.f); n > limit {
   330  			t.Errorf("%s allocs = %f want %d", test.desc, n, int(limit))
   331  		}
   332  	}
   333  }
   334  
   335  func TestSimultaneousCancels(t *testing.T) {
   336  	root, cancel := WithCancel(Background())
   337  	m := map[Context]CancelFunc{root: cancel}
   338  	q := []Context{root}
   339  	// Create a tree of contexts.
   340  	for len(q) != 0 && len(m) < 100 {
   341  		parent := q[0]
   342  		q = q[1:]
   343  		for i := 0; i < 4; i++ {
   344  			ctx, cancel := WithCancel(parent)
   345  			m[ctx] = cancel
   346  			q = append(q, ctx)
   347  		}
   348  	}
   349  	// Start all the cancels in a random order.
   350  	var wg sync.WaitGroup
   351  	wg.Add(len(m))
   352  	for _, cancel := range m {
   353  		go func(cancel CancelFunc) {
   354  			cancel()
   355  			wg.Done()
   356  		}(cancel)
   357  	}
   358  
   359  	d := quiescent(t)
   360  	stuck := make(chan struct{})
   361  	timer := time.AfterFunc(d, func() { close(stuck) })
   362  	defer timer.Stop()
   363  
   364  	// Wait on all the contexts in a random order.
   365  	for ctx := range m {
   366  		select {
   367  		case <-ctx.Done():
   368  		case <-stuck:
   369  			buf := make([]byte, 10<<10)
   370  			n := runtime.Stack(buf, true)
   371  			t.Fatalf("timed out after %v waiting for <-ctx.Done(); stacks:\n%s", d, buf[:n])
   372  		}
   373  	}
   374  	// Wait for all the cancel functions to return.
   375  	done := make(chan struct{})
   376  	go func() {
   377  		wg.Wait()
   378  		close(done)
   379  	}()
   380  	select {
   381  	case <-done:
   382  	case <-stuck:
   383  		buf := make([]byte, 10<<10)
   384  		n := runtime.Stack(buf, true)
   385  		t.Fatalf("timed out after %v waiting for cancel functions; stacks:\n%s", d, buf[:n])
   386  	}
   387  }
   388  
   389  func TestInterlockedCancels(t *testing.T) {
   390  	parent, cancelParent := WithCancel(Background())
   391  	child, cancelChild := WithCancel(parent)
   392  	go func() {
   393  		<-parent.Done()
   394  		cancelChild()
   395  	}()
   396  	cancelParent()
   397  	d := quiescent(t)
   398  	timer := time.NewTimer(d)
   399  	defer timer.Stop()
   400  	select {
   401  	case <-child.Done():
   402  	case <-timer.C:
   403  		buf := make([]byte, 10<<10)
   404  		n := runtime.Stack(buf, true)
   405  		t.Fatalf("timed out after %v waiting for child.Done(); stacks:\n%s", d, buf[:n])
   406  	}
   407  }
   408  
   409  func TestLayersCancel(t *testing.T) {
   410  	testLayers(t, time.Now().UnixNano(), false)
   411  }
   412  
   413  func TestLayersTimeout(t *testing.T) {
   414  	testLayers(t, time.Now().UnixNano(), true)
   415  }
   416  
   417  func testLayers(t *testing.T, seed int64, testTimeout bool) {
   418  	t.Parallel()
   419  
   420  	r := rand.New(rand.NewSource(seed))
   421  	prefix := fmt.Sprintf("seed=%d", seed)
   422  	errorf := func(format string, a ...any) {
   423  		t.Errorf(prefix+format, a...)
   424  	}
   425  	const (
   426  		minLayers = 30
   427  	)
   428  	type value int
   429  	var (
   430  		vals      []*value
   431  		cancels   []CancelFunc
   432  		numTimers int
   433  		ctx       = Background()
   434  	)
   435  	for i := 0; i < minLayers || numTimers == 0 || len(cancels) == 0 || len(vals) == 0; i++ {
   436  		switch r.Intn(3) {
   437  		case 0:
   438  			v := new(value)
   439  			ctx = WithValue(ctx, v, v)
   440  			vals = append(vals, v)
   441  		case 1:
   442  			var cancel CancelFunc
   443  			ctx, cancel = WithCancel(ctx)
   444  			cancels = append(cancels, cancel)
   445  		case 2:
   446  			var cancel CancelFunc
   447  			d := veryLongDuration
   448  			if testTimeout {
   449  				d = shortDuration
   450  			}
   451  			ctx, cancel = WithTimeout(ctx, d)
   452  			cancels = append(cancels, cancel)
   453  			numTimers++
   454  		}
   455  	}
   456  	checkValues := func(when string) {
   457  		for _, key := range vals {
   458  			if val := ctx.Value(key).(*value); key != val {
   459  				errorf("%s: ctx.Value(%p) = %p want %p", when, key, val, key)
   460  			}
   461  		}
   462  	}
   463  	if !testTimeout {
   464  		select {
   465  		case <-ctx.Done():
   466  			errorf("ctx should not be canceled yet")
   467  		default:
   468  		}
   469  	}
   470  	if s, prefix := fmt.Sprint(ctx), "context.Background."; !strings.HasPrefix(s, prefix) {
   471  		t.Errorf("ctx.String() = %q want prefix %q", s, prefix)
   472  	}
   473  	t.Log(ctx)
   474  	checkValues("before cancel")
   475  	if testTimeout {
   476  		d := quiescent(t)
   477  		timer := time.NewTimer(d)
   478  		defer timer.Stop()
   479  		select {
   480  		case <-ctx.Done():
   481  		case <-timer.C:
   482  			errorf("ctx should have timed out after %v", d)
   483  		}
   484  		checkValues("after timeout")
   485  	} else {
   486  		cancel := cancels[r.Intn(len(cancels))]
   487  		cancel()
   488  		select {
   489  		case <-ctx.Done():
   490  		default:
   491  			errorf("ctx should be canceled")
   492  		}
   493  		checkValues("after cancel")
   494  	}
   495  }
   496  
   497  func TestWithCancelCanceledParent(t *testing.T) {
   498  	parent, pcancel := WithCancelCause(Background())
   499  	cause := fmt.Errorf("Because!")
   500  	pcancel(cause)
   501  
   502  	c, _ := WithCancel(parent)
   503  	select {
   504  	case <-c.Done():
   505  	default:
   506  		t.Errorf("child not done immediately upon construction")
   507  	}
   508  	if got, want := c.Err(), Canceled; got != want {
   509  		t.Errorf("child not canceled; got = %v, want = %v", got, want)
   510  	}
   511  	if got, want := Cause(c), cause; got != want {
   512  		t.Errorf("child has wrong cause; got = %v, want = %v", got, want)
   513  	}
   514  }
   515  
   516  func TestWithCancelSimultaneouslyCanceledParent(t *testing.T) {
   517  	// Cancel the parent goroutine concurrently with creating a child.
   518  	for i := 0; i < 100; i++ {
   519  		parent, pcancel := WithCancelCause(Background())
   520  		cause := fmt.Errorf("Because!")
   521  		go pcancel(cause)
   522  
   523  		c, _ := WithCancel(parent)
   524  		<-c.Done()
   525  		if got, want := c.Err(), Canceled; got != want {
   526  			t.Errorf("child not canceled; got = %v, want = %v", got, want)
   527  		}
   528  		if got, want := Cause(c), cause; got != want {
   529  			t.Errorf("child has wrong cause; got = %v, want = %v", got, want)
   530  		}
   531  	}
   532  }
   533  
   534  func TestWithValueChecksKey(t *testing.T) {
   535  	panicVal := recoveredValue(func() { _ = WithValue(Background(), []byte("foo"), "bar") })
   536  	if panicVal == nil {
   537  		t.Error("expected panic")
   538  	}
   539  	panicVal = recoveredValue(func() { _ = WithValue(Background(), nil, "bar") })
   540  	if got, want := fmt.Sprint(panicVal), "nil key"; got != want {
   541  		t.Errorf("panic = %q; want %q", got, want)
   542  	}
   543  }
   544  
   545  func TestInvalidDerivedFail(t *testing.T) {
   546  	panicVal := recoveredValue(func() { _, _ = WithCancel(nil) })
   547  	if panicVal == nil {
   548  		t.Error("expected panic")
   549  	}
   550  	panicVal = recoveredValue(func() { _, _ = WithDeadline(nil, time.Now().Add(shortDuration)) })
   551  	if panicVal == nil {
   552  		t.Error("expected panic")
   553  	}
   554  	panicVal = recoveredValue(func() { _ = WithValue(nil, "foo", "bar") })
   555  	if panicVal == nil {
   556  		t.Error("expected panic")
   557  	}
   558  }
   559  
   560  func recoveredValue(fn func()) (v any) {
   561  	defer func() { v = recover() }()
   562  	fn()
   563  	return
   564  }
   565  
   566  func TestDeadlineExceededSupportsTimeout(t *testing.T) {
   567  	i, ok := DeadlineExceeded.(interface {
   568  		Timeout() bool
   569  	})
   570  	if !ok {
   571  		t.Fatal("DeadlineExceeded does not support Timeout interface")
   572  	}
   573  	if !i.Timeout() {
   574  		t.Fatal("wrong value for timeout")
   575  	}
   576  }
   577  func TestCause(t *testing.T) {
   578  	var (
   579  		forever       = 1e6 * time.Second
   580  		parentCause   = fmt.Errorf("parentCause")
   581  		childCause    = fmt.Errorf("childCause")
   582  		tooSlow       = fmt.Errorf("tooSlow")
   583  		finishedEarly = fmt.Errorf("finishedEarly")
   584  	)
   585  	for _, test := range []struct {
   586  		name  string
   587  		ctx   func() Context
   588  		err   error
   589  		cause error
   590  	}{
   591  		{
   592  			name:  "Background",
   593  			ctx:   Background,
   594  			err:   nil,
   595  			cause: nil,
   596  		},
   597  		{
   598  			name:  "TODO",
   599  			ctx:   TODO,
   600  			err:   nil,
   601  			cause: nil,
   602  		},
   603  		{
   604  			name: "WithCancel",
   605  			ctx: func() Context {
   606  				ctx, cancel := WithCancel(Background())
   607  				cancel()
   608  				return ctx
   609  			},
   610  			err:   Canceled,
   611  			cause: Canceled,
   612  		},
   613  		{
   614  			name: "WithCancelCause",
   615  			ctx: func() Context {
   616  				ctx, cancel := WithCancelCause(Background())
   617  				cancel(parentCause)
   618  				return ctx
   619  			},
   620  			err:   Canceled,
   621  			cause: parentCause,
   622  		},
   623  		{
   624  			name: "WithCancelCause nil",
   625  			ctx: func() Context {
   626  				ctx, cancel := WithCancelCause(Background())
   627  				cancel(nil)
   628  				return ctx
   629  			},
   630  			err:   Canceled,
   631  			cause: Canceled,
   632  		},
   633  		{
   634  			name: "WithCancelCause: parent cause before child",
   635  			ctx: func() Context {
   636  				ctx, cancelParent := WithCancelCause(Background())
   637  				ctx, cancelChild := WithCancelCause(ctx)
   638  				cancelParent(parentCause)
   639  				cancelChild(childCause)
   640  				return ctx
   641  			},
   642  			err:   Canceled,
   643  			cause: parentCause,
   644  		},
   645  		{
   646  			name: "WithCancelCause: parent cause after child",
   647  			ctx: func() Context {
   648  				ctx, cancelParent := WithCancelCause(Background())
   649  				ctx, cancelChild := WithCancelCause(ctx)
   650  				cancelChild(childCause)
   651  				cancelParent(parentCause)
   652  				return ctx
   653  			},
   654  			err:   Canceled,
   655  			cause: childCause,
   656  		},
   657  		{
   658  			name: "WithCancelCause: parent cause before nil",
   659  			ctx: func() Context {
   660  				ctx, cancelParent := WithCancelCause(Background())
   661  				ctx, cancelChild := WithCancel(ctx)
   662  				cancelParent(parentCause)
   663  				cancelChild()
   664  				return ctx
   665  			},
   666  			err:   Canceled,
   667  			cause: parentCause,
   668  		},
   669  		{
   670  			name: "WithCancelCause: parent cause after nil",
   671  			ctx: func() Context {
   672  				ctx, cancelParent := WithCancelCause(Background())
   673  				ctx, cancelChild := WithCancel(ctx)
   674  				cancelChild()
   675  				cancelParent(parentCause)
   676  				return ctx
   677  			},
   678  			err:   Canceled,
   679  			cause: Canceled,
   680  		},
   681  		{
   682  			name: "WithCancelCause: child cause after nil",
   683  			ctx: func() Context {
   684  				ctx, cancelParent := WithCancel(Background())
   685  				ctx, cancelChild := WithCancelCause(ctx)
   686  				cancelParent()
   687  				cancelChild(childCause)
   688  				return ctx
   689  			},
   690  			err:   Canceled,
   691  			cause: Canceled,
   692  		},
   693  		{
   694  			name: "WithCancelCause: child cause before nil",
   695  			ctx: func() Context {
   696  				ctx, cancelParent := WithCancel(Background())
   697  				ctx, cancelChild := WithCancelCause(ctx)
   698  				cancelChild(childCause)
   699  				cancelParent()
   700  				return ctx
   701  			},
   702  			err:   Canceled,
   703  			cause: childCause,
   704  		},
   705  		{
   706  			name: "WithTimeout",
   707  			ctx: func() Context {
   708  				ctx, cancel := WithTimeout(Background(), 0)
   709  				cancel()
   710  				return ctx
   711  			},
   712  			err:   DeadlineExceeded,
   713  			cause: DeadlineExceeded,
   714  		},
   715  		{
   716  			name: "WithTimeout canceled",
   717  			ctx: func() Context {
   718  				ctx, cancel := WithTimeout(Background(), forever)
   719  				cancel()
   720  				return ctx
   721  			},
   722  			err:   Canceled,
   723  			cause: Canceled,
   724  		},
   725  		{
   726  			name: "WithTimeoutCause",
   727  			ctx: func() Context {
   728  				ctx, cancel := WithTimeoutCause(Background(), 0, tooSlow)
   729  				cancel()
   730  				return ctx
   731  			},
   732  			err:   DeadlineExceeded,
   733  			cause: tooSlow,
   734  		},
   735  		{
   736  			name: "WithTimeoutCause canceled",
   737  			ctx: func() Context {
   738  				ctx, cancel := WithTimeoutCause(Background(), forever, tooSlow)
   739  				cancel()
   740  				return ctx
   741  			},
   742  			err:   Canceled,
   743  			cause: Canceled,
   744  		},
   745  		{
   746  			name: "WithTimeoutCause stacked",
   747  			ctx: func() Context {
   748  				ctx, cancel := WithCancelCause(Background())
   749  				ctx, _ = WithTimeoutCause(ctx, 0, tooSlow)
   750  				cancel(finishedEarly)
   751  				return ctx
   752  			},
   753  			err:   DeadlineExceeded,
   754  			cause: tooSlow,
   755  		},
   756  		{
   757  			name: "WithTimeoutCause stacked canceled",
   758  			ctx: func() Context {
   759  				ctx, cancel := WithCancelCause(Background())
   760  				ctx, _ = WithTimeoutCause(ctx, forever, tooSlow)
   761  				cancel(finishedEarly)
   762  				return ctx
   763  			},
   764  			err:   Canceled,
   765  			cause: finishedEarly,
   766  		},
   767  		{
   768  			name: "WithoutCancel",
   769  			ctx: func() Context {
   770  				return WithoutCancel(Background())
   771  			},
   772  			err:   nil,
   773  			cause: nil,
   774  		},
   775  		{
   776  			name: "WithoutCancel canceled",
   777  			ctx: func() Context {
   778  				ctx, cancel := WithCancelCause(Background())
   779  				ctx = WithoutCancel(ctx)
   780  				cancel(finishedEarly)
   781  				return ctx
   782  			},
   783  			err:   nil,
   784  			cause: nil,
   785  		},
   786  		{
   787  			name: "WithoutCancel timeout",
   788  			ctx: func() Context {
   789  				ctx, cancel := WithTimeoutCause(Background(), 0, tooSlow)
   790  				ctx = WithoutCancel(ctx)
   791  				cancel()
   792  				return ctx
   793  			},
   794  			err:   nil,
   795  			cause: nil,
   796  		},
   797  	} {
   798  		test := test
   799  		t.Run(test.name, func(t *testing.T) {
   800  			t.Parallel()
   801  			ctx := test.ctx()
   802  			if got, want := ctx.Err(), test.err; want != got {
   803  				t.Errorf("ctx.Err() = %v want %v", got, want)
   804  			}
   805  			if got, want := Cause(ctx), test.cause; want != got {
   806  				t.Errorf("Cause(ctx) = %v want %v", got, want)
   807  			}
   808  		})
   809  	}
   810  }
   811  
   812  func TestCauseRace(t *testing.T) {
   813  	cause := errors.New("TestCauseRace")
   814  	ctx, cancel := WithCancelCause(Background())
   815  	go func() {
   816  		cancel(cause)
   817  	}()
   818  	for {
   819  		// Poll Cause, rather than waiting for Done, to test that
   820  		// access to the underlying cause is synchronized properly.
   821  		if err := Cause(ctx); err != nil {
   822  			if err != cause {
   823  				t.Errorf("Cause returned %v, want %v", err, cause)
   824  			}
   825  			break
   826  		}
   827  		runtime.Gosched()
   828  	}
   829  }
   830  
   831  func TestWithoutCancel(t *testing.T) {
   832  	key, value := "key", "value"
   833  	ctx := WithValue(Background(), key, value)
   834  	ctx = WithoutCancel(ctx)
   835  	if d, ok := ctx.Deadline(); !d.IsZero() || ok != false {
   836  		t.Errorf("ctx.Deadline() = %v, %v want zero, false", d, ok)
   837  	}
   838  	if done := ctx.Done(); done != nil {
   839  		t.Errorf("ctx.Deadline() = %v want nil", done)
   840  	}
   841  	if err := ctx.Err(); err != nil {
   842  		t.Errorf("ctx.Err() = %v want nil", err)
   843  	}
   844  	if v := ctx.Value(key); v != value {
   845  		t.Errorf("ctx.Value(%q) = %q want %q", key, v, value)
   846  	}
   847  }
   848  
   849  type customDoneContext struct {
   850  	Context
   851  	donec chan struct{}
   852  }
   853  
   854  func (c *customDoneContext) Done() <-chan struct{} {
   855  	return c.donec
   856  }
   857  
   858  func TestCustomContextPropagation(t *testing.T) {
   859  	cause := errors.New("TestCustomContextPropagation")
   860  	donec := make(chan struct{})
   861  	ctx1, cancel1 := WithCancelCause(Background())
   862  	ctx2 := &customDoneContext{
   863  		Context: ctx1,
   864  		donec:   donec,
   865  	}
   866  	ctx3, cancel3 := WithCancel(ctx2)
   867  	defer cancel3()
   868  
   869  	cancel1(cause)
   870  	close(donec)
   871  
   872  	<-ctx3.Done()
   873  	if got, want := ctx3.Err(), Canceled; got != want {
   874  		t.Errorf("child not canceled; got = %v, want = %v", got, want)
   875  	}
   876  	if got, want := Cause(ctx3), cause; got != want {
   877  		t.Errorf("child has wrong cause; got = %v, want = %v", got, want)
   878  	}
   879  }
   880  
   881  // customCauseContext is a custom Context used to test context.Cause.
   882  type customCauseContext struct {
   883  	mu   sync.Mutex
   884  	done chan struct{}
   885  	err  error
   886  
   887  	cancelChild CancelFunc
   888  }
   889  
   890  func (ccc *customCauseContext) Deadline() (deadline time.Time, ok bool) {
   891  	return
   892  }
   893  
   894  func (ccc *customCauseContext) Done() <-chan struct{} {
   895  	ccc.mu.Lock()
   896  	defer ccc.mu.Unlock()
   897  	return ccc.done
   898  }
   899  
   900  func (ccc *customCauseContext) Err() error {
   901  	ccc.mu.Lock()
   902  	defer ccc.mu.Unlock()
   903  	return ccc.err
   904  }
   905  
   906  func (ccc *customCauseContext) Value(key any) any {
   907  	return nil
   908  }
   909  
   910  func (ccc *customCauseContext) cancel() {
   911  	ccc.mu.Lock()
   912  	ccc.err = Canceled
   913  	close(ccc.done)
   914  	cancelChild := ccc.cancelChild
   915  	ccc.mu.Unlock()
   916  
   917  	if cancelChild != nil {
   918  		cancelChild()
   919  	}
   920  }
   921  
   922  func (ccc *customCauseContext) setCancelChild(cancelChild CancelFunc) {
   923  	ccc.cancelChild = cancelChild
   924  }
   925  
   926  func TestCustomContextCause(t *testing.T) {
   927  	// Test if we cancel a custom context, Err and Cause return Canceled.
   928  	ccc := &customCauseContext{
   929  		done: make(chan struct{}),
   930  	}
   931  	ccc.cancel()
   932  	if got := ccc.Err(); got != Canceled {
   933  		t.Errorf("ccc.Err() = %v, want %v", got, Canceled)
   934  	}
   935  	if got := Cause(ccc); got != Canceled {
   936  		t.Errorf("Cause(ccc) = %v, want %v", got, Canceled)
   937  	}
   938  
   939  	// Test that if we pass a custom context to WithCancelCause,
   940  	// and then cancel that child context with a cause,
   941  	// that the cause of the child canceled context is correct
   942  	// but that the parent custom context is not canceled.
   943  	ccc = &customCauseContext{
   944  		done: make(chan struct{}),
   945  	}
   946  	ctx, causeFunc := WithCancelCause(ccc)
   947  	cause := errors.New("TestCustomContextCause")
   948  	causeFunc(cause)
   949  	if got := ctx.Err(); got != Canceled {
   950  		t.Errorf("after CancelCauseFunc ctx.Err() = %v, want %v", got, Canceled)
   951  	}
   952  	if got := Cause(ctx); got != cause {
   953  		t.Errorf("after CancelCauseFunc Cause(ctx) = %v, want %v", got, cause)
   954  	}
   955  	if got := ccc.Err(); got != nil {
   956  		t.Errorf("after CancelCauseFunc ccc.Err() = %v, want %v", got, nil)
   957  	}
   958  	if got := Cause(ccc); got != nil {
   959  		t.Errorf("after CancelCauseFunc Cause(ccc) = %v, want %v", got, nil)
   960  	}
   961  
   962  	// Test that if we now cancel the parent custom context,
   963  	// the cause of the child canceled context is still correct,
   964  	// and the parent custom context is canceled without a cause.
   965  	ccc.cancel()
   966  	if got := ctx.Err(); got != Canceled {
   967  		t.Errorf("after CancelCauseFunc ctx.Err() = %v, want %v", got, Canceled)
   968  	}
   969  	if got := Cause(ctx); got != cause {
   970  		t.Errorf("after CancelCauseFunc Cause(ctx) = %v, want %v", got, cause)
   971  	}
   972  	if got := ccc.Err(); got != Canceled {
   973  		t.Errorf("after CancelCauseFunc ccc.Err() = %v, want %v", got, Canceled)
   974  	}
   975  	if got := Cause(ccc); got != Canceled {
   976  		t.Errorf("after CancelCauseFunc Cause(ccc) = %v, want %v", got, Canceled)
   977  	}
   978  
   979  	// Test that if we associate a custom context with a child,
   980  	// then canceling the custom context cancels the child.
   981  	ccc = &customCauseContext{
   982  		done: make(chan struct{}),
   983  	}
   984  	ctx, cancelFunc := WithCancel(ccc)
   985  	ccc.setCancelChild(cancelFunc)
   986  	ccc.cancel()
   987  	if got := ctx.Err(); got != Canceled {
   988  		t.Errorf("after CancelCauseFunc ctx.Err() = %v, want %v", got, Canceled)
   989  	}
   990  	if got := Cause(ctx); got != Canceled {
   991  		t.Errorf("after CancelCauseFunc Cause(ctx) = %v, want %v", got, Canceled)
   992  	}
   993  	if got := ccc.Err(); got != Canceled {
   994  		t.Errorf("after CancelCauseFunc ccc.Err() = %v, want %v", got, Canceled)
   995  	}
   996  	if got := Cause(ccc); got != Canceled {
   997  		t.Errorf("after CancelCauseFunc Cause(ccc) = %v, want %v", got, Canceled)
   998  	}
   999  }
  1000  
  1001  func TestAfterFuncCalledAfterCancel(t *testing.T) {
  1002  	ctx, cancel := WithCancel(Background())
  1003  	donec := make(chan struct{})
  1004  	stop := AfterFunc(ctx, func() {
  1005  		close(donec)
  1006  	})
  1007  	select {
  1008  	case <-donec:
  1009  		t.Fatalf("AfterFunc called before context is done")
  1010  	case <-time.After(shortDuration):
  1011  	}
  1012  	cancel()
  1013  	select {
  1014  	case <-donec:
  1015  	case <-time.After(veryLongDuration):
  1016  		t.Fatalf("AfterFunc not called after context is canceled")
  1017  	}
  1018  	if stop() {
  1019  		t.Fatalf("stop() = true, want false")
  1020  	}
  1021  }
  1022  
  1023  func TestAfterFuncCalledAfterTimeout(t *testing.T) {
  1024  	ctx, cancel := WithTimeout(Background(), shortDuration)
  1025  	defer cancel()
  1026  	donec := make(chan struct{})
  1027  	AfterFunc(ctx, func() {
  1028  		close(donec)
  1029  	})
  1030  	select {
  1031  	case <-donec:
  1032  	case <-time.After(veryLongDuration):
  1033  		t.Fatalf("AfterFunc not called after context is canceled")
  1034  	}
  1035  }
  1036  
  1037  func TestAfterFuncCalledImmediately(t *testing.T) {
  1038  	ctx, cancel := WithCancel(Background())
  1039  	cancel()
  1040  	donec := make(chan struct{})
  1041  	AfterFunc(ctx, func() {
  1042  		close(donec)
  1043  	})
  1044  	select {
  1045  	case <-donec:
  1046  	case <-time.After(veryLongDuration):
  1047  		t.Fatalf("AfterFunc not called for already-canceled context")
  1048  	}
  1049  }
  1050  
  1051  func TestAfterFuncNotCalledAfterStop(t *testing.T) {
  1052  	ctx, cancel := WithCancel(Background())
  1053  	donec := make(chan struct{})
  1054  	stop := AfterFunc(ctx, func() {
  1055  		close(donec)
  1056  	})
  1057  	if !stop() {
  1058  		t.Fatalf("stop() = false, want true")
  1059  	}
  1060  	cancel()
  1061  	select {
  1062  	case <-donec:
  1063  		t.Fatalf("AfterFunc called for already-canceled context")
  1064  	case <-time.After(shortDuration):
  1065  	}
  1066  	if stop() {
  1067  		t.Fatalf("stop() = true, want false")
  1068  	}
  1069  }
  1070  
  1071  // This test verifies that canceling a context does not block waiting for AfterFuncs to finish.
  1072  func TestAfterFuncCalledAsynchronously(t *testing.T) {
  1073  	ctx, cancel := WithCancel(Background())
  1074  	donec := make(chan struct{})
  1075  	stop := AfterFunc(ctx, func() {
  1076  		// The channel send blocks until donec is read from.
  1077  		donec <- struct{}{}
  1078  	})
  1079  	defer stop()
  1080  	cancel()
  1081  	// After cancel returns, read from donec and unblock the AfterFunc.
  1082  	select {
  1083  	case <-donec:
  1084  	case <-time.After(veryLongDuration):
  1085  		t.Fatalf("AfterFunc not called after context is canceled")
  1086  	}
  1087  }
  1088  

View as plain text