1// Copyright 2023 The Go Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5package context_test
6
7import (
8	"context"
9	"sync"
10	"testing"
11	"time"
12)
13
// afterFuncContext is a context.Context implementation that is deliberately
// not one of the types defined in context.go. It supports registering
// AfterFuncs through its AfterFunc method and serves as the parent context
// in the TestCustomContextAfterFunc* tests.
//
// The zero value is ready to use: the done channel and the afterFuncs map are
// allocated lazily by the methods that need them.
type afterFuncContext struct {
	mu sync.Mutex // guards the fields below

	err        error            // set once by cancel; nil until then
	done       chan struct{}    // allocated on first call to Done
	afterFuncs map[*byte]func() // registered callbacks, keyed by a unique pointer
}
22
23func newAfterFuncContext() context.Context {
24	return &afterFuncContext{}
25}
26
27func (c *afterFuncContext) Deadline() (time.Time, bool) {
28	return time.Time{}, false
29}
30
31func (c *afterFuncContext) Done() <-chan struct{} {
32	c.mu.Lock()
33	defer c.mu.Unlock()
34	if c.done == nil {
35		c.done = make(chan struct{})
36	}
37	return c.done
38}
39
40func (c *afterFuncContext) Err() error {
41	c.mu.Lock()
42	defer c.mu.Unlock()
43	return c.err
44}
45
46func (c *afterFuncContext) Value(key any) any {
47	return nil
48}
49
50func (c *afterFuncContext) AfterFunc(f func()) func() bool {
51	c.mu.Lock()
52	defer c.mu.Unlock()
53	k := new(byte)
54	if c.afterFuncs == nil {
55		c.afterFuncs = make(map[*byte]func())
56	}
57	c.afterFuncs[k] = f
58	return func() bool {
59		c.mu.Lock()
60		defer c.mu.Unlock()
61		_, ok := c.afterFuncs[k]
62		delete(c.afterFuncs, k)
63		return ok
64	}
65}
66
67func (c *afterFuncContext) cancel(err error) {
68	c.mu.Lock()
69	defer c.mu.Unlock()
70	if c.err != nil {
71		return
72	}
73	c.err = err
74	for _, f := range c.afterFuncs {
75		go f()
76	}
77	c.afterFuncs = nil
78}
79
80func TestCustomContextAfterFuncCancel(t *testing.T) {
81	ctx0 := &afterFuncContext{}
82	ctx1, cancel := context.WithCancel(ctx0)
83	defer cancel()
84	ctx0.cancel(context.Canceled)
85	<-ctx1.Done()
86}
87
88func TestCustomContextAfterFuncTimeout(t *testing.T) {
89	ctx0 := &afterFuncContext{}
90	ctx1, cancel := context.WithTimeout(ctx0, veryLongDuration)
91	defer cancel()
92	ctx0.cancel(context.Canceled)
93	<-ctx1.Done()
94}
95
96func TestCustomContextAfterFuncAfterFunc(t *testing.T) {
97	ctx0 := &afterFuncContext{}
98	donec := make(chan struct{})
99	stop := context.AfterFunc(ctx0, func() {
100		close(donec)
101	})
102	defer stop()
103	ctx0.cancel(context.Canceled)
104	<-donec
105}
106
107func TestCustomContextAfterFuncUnregisterCancel(t *testing.T) {
108	ctx0 := &afterFuncContext{}
109	_, cancel1 := context.WithCancel(ctx0)
110	_, cancel2 := context.WithCancel(ctx0)
111	if got, want := len(ctx0.afterFuncs), 2; got != want {
112		t.Errorf("after WithCancel(ctx0): ctx0 has %v afterFuncs, want %v", got, want)
113	}
114	cancel1()
115	cancel2()
116	if got, want := len(ctx0.afterFuncs), 0; got != want {
117		t.Errorf("after canceling WithCancel(ctx0): ctx0 has %v afterFuncs, want %v", got, want)
118	}
119}
120
121func TestCustomContextAfterFuncUnregisterTimeout(t *testing.T) {
122	ctx0 := &afterFuncContext{}
123	_, cancel := context.WithTimeout(ctx0, veryLongDuration)
124	if got, want := len(ctx0.afterFuncs), 1; got != want {
125		t.Errorf("after WithTimeout(ctx0, d): ctx0 has %v afterFuncs, want %v", got, want)
126	}
127	cancel()
128	if got, want := len(ctx0.afterFuncs), 0; got != want {
129		t.Errorf("after canceling WithTimeout(ctx0, d): ctx0 has %v afterFuncs, want %v", got, want)
130	}
131}
132
133func TestCustomContextAfterFuncUnregisterAfterFunc(t *testing.T) {
134	ctx0 := &afterFuncContext{}
135	stop := context.AfterFunc(ctx0, func() {})
136	if got, want := len(ctx0.afterFuncs), 1; got != want {
137		t.Errorf("after AfterFunc(ctx0, f): ctx0 has %v afterFuncs, want %v", got, want)
138	}
139	stop()
140	if got, want := len(ctx0.afterFuncs), 0; got != want {
141		t.Errorf("after stopping AfterFunc(ctx0, f): ctx0 has %v afterFuncs, want %v", got, want)
142	}
143}
144