1// Copyright 2024 The Go Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5package concurrent
6
7import (
8	"fmt"
9	"math"
10	"runtime"
11	"strconv"
12	"strings"
13	"sync"
14	"testing"
15	"unsafe"
16)
17
18func TestHashTrieMap(t *testing.T) {
19	testHashTrieMap(t, func() *HashTrieMap[string, int] {
20		return NewHashTrieMap[string, int]()
21	})
22}
23
24func TestHashTrieMapBadHash(t *testing.T) {
25	testHashTrieMap(t, func() *HashTrieMap[string, int] {
26		// Stub out the good hash function with a terrible one.
27		// Everything should still work as expected.
28		m := NewHashTrieMap[string, int]()
29		m.keyHash = func(_ unsafe.Pointer, _ uintptr) uintptr {
30			return 0
31		}
32		return m
33	})
34}
35
36func testHashTrieMap(t *testing.T, newMap func() *HashTrieMap[string, int]) {
37	t.Run("LoadEmpty", func(t *testing.T) {
38		m := newMap()
39
40		for _, s := range testData {
41			expectMissing(t, s, 0)(m.Load(s))
42		}
43	})
44	t.Run("LoadOrStore", func(t *testing.T) {
45		m := newMap()
46
47		for i, s := range testData {
48			expectMissing(t, s, 0)(m.Load(s))
49			expectStored(t, s, i)(m.LoadOrStore(s, i))
50			expectPresent(t, s, i)(m.Load(s))
51			expectLoaded(t, s, i)(m.LoadOrStore(s, 0))
52		}
53		for i, s := range testData {
54			expectPresent(t, s, i)(m.Load(s))
55			expectLoaded(t, s, i)(m.LoadOrStore(s, 0))
56		}
57	})
58	t.Run("CompareAndDeleteAll", func(t *testing.T) {
59		m := newMap()
60
61		for range 3 {
62			for i, s := range testData {
63				expectMissing(t, s, 0)(m.Load(s))
64				expectStored(t, s, i)(m.LoadOrStore(s, i))
65				expectPresent(t, s, i)(m.Load(s))
66				expectLoaded(t, s, i)(m.LoadOrStore(s, 0))
67			}
68			for i, s := range testData {
69				expectPresent(t, s, i)(m.Load(s))
70				expectNotDeleted(t, s, math.MaxInt)(m.CompareAndDelete(s, math.MaxInt))
71				expectDeleted(t, s, i)(m.CompareAndDelete(s, i))
72				expectNotDeleted(t, s, i)(m.CompareAndDelete(s, i))
73				expectMissing(t, s, 0)(m.Load(s))
74			}
75			for _, s := range testData {
76				expectMissing(t, s, 0)(m.Load(s))
77			}
78		}
79	})
80	t.Run("CompareAndDeleteOne", func(t *testing.T) {
81		m := newMap()
82
83		for i, s := range testData {
84			expectMissing(t, s, 0)(m.Load(s))
85			expectStored(t, s, i)(m.LoadOrStore(s, i))
86			expectPresent(t, s, i)(m.Load(s))
87			expectLoaded(t, s, i)(m.LoadOrStore(s, 0))
88		}
89		expectNotDeleted(t, testData[15], math.MaxInt)(m.CompareAndDelete(testData[15], math.MaxInt))
90		expectDeleted(t, testData[15], 15)(m.CompareAndDelete(testData[15], 15))
91		expectNotDeleted(t, testData[15], 15)(m.CompareAndDelete(testData[15], 15))
92		for i, s := range testData {
93			if i == 15 {
94				expectMissing(t, s, 0)(m.Load(s))
95			} else {
96				expectPresent(t, s, i)(m.Load(s))
97			}
98		}
99	})
100	t.Run("DeleteMultiple", func(t *testing.T) {
101		m := newMap()
102
103		for i, s := range testData {
104			expectMissing(t, s, 0)(m.Load(s))
105			expectStored(t, s, i)(m.LoadOrStore(s, i))
106			expectPresent(t, s, i)(m.Load(s))
107			expectLoaded(t, s, i)(m.LoadOrStore(s, 0))
108		}
109		for _, i := range []int{1, 105, 6, 85} {
110			expectNotDeleted(t, testData[i], math.MaxInt)(m.CompareAndDelete(testData[i], math.MaxInt))
111			expectDeleted(t, testData[i], i)(m.CompareAndDelete(testData[i], i))
112			expectNotDeleted(t, testData[i], i)(m.CompareAndDelete(testData[i], i))
113		}
114		for i, s := range testData {
115			if i == 1 || i == 105 || i == 6 || i == 85 {
116				expectMissing(t, s, 0)(m.Load(s))
117			} else {
118				expectPresent(t, s, i)(m.Load(s))
119			}
120		}
121	})
122	t.Run("All", func(t *testing.T) {
123		m := newMap()
124
125		testAll(t, m, testDataMap(testData[:]), func(_ string, _ int) bool {
126			return true
127		})
128	})
129	t.Run("AllDelete", func(t *testing.T) {
130		m := newMap()
131
132		testAll(t, m, testDataMap(testData[:]), func(s string, i int) bool {
133			expectDeleted(t, s, i)(m.CompareAndDelete(s, i))
134			return true
135		})
136		for _, s := range testData {
137			expectMissing(t, s, 0)(m.Load(s))
138		}
139	})
140	t.Run("ConcurrentLifecycleUnsharedKeys", func(t *testing.T) {
141		m := newMap()
142
143		gmp := runtime.GOMAXPROCS(-1)
144		var wg sync.WaitGroup
145		for i := range gmp {
146			wg.Add(1)
147			go func(id int) {
148				defer wg.Done()
149
150				makeKey := func(s string) string {
151					return s + "-" + strconv.Itoa(id)
152				}
153				for _, s := range testData {
154					key := makeKey(s)
155					expectMissing(t, key, 0)(m.Load(key))
156					expectStored(t, key, id)(m.LoadOrStore(key, id))
157					expectPresent(t, key, id)(m.Load(key))
158					expectLoaded(t, key, id)(m.LoadOrStore(key, 0))
159				}
160				for _, s := range testData {
161					key := makeKey(s)
162					expectPresent(t, key, id)(m.Load(key))
163					expectDeleted(t, key, id)(m.CompareAndDelete(key, id))
164					expectMissing(t, key, 0)(m.Load(key))
165				}
166				for _, s := range testData {
167					key := makeKey(s)
168					expectMissing(t, key, 0)(m.Load(key))
169				}
170			}(i)
171		}
172		wg.Wait()
173	})
174	t.Run("ConcurrentDeleteSharedKeys", func(t *testing.T) {
175		m := newMap()
176
177		// Load up the map.
178		for i, s := range testData {
179			expectMissing(t, s, 0)(m.Load(s))
180			expectStored(t, s, i)(m.LoadOrStore(s, i))
181		}
182		gmp := runtime.GOMAXPROCS(-1)
183		var wg sync.WaitGroup
184		for i := range gmp {
185			wg.Add(1)
186			go func(id int) {
187				defer wg.Done()
188
189				for i, s := range testData {
190					expectNotDeleted(t, s, math.MaxInt)(m.CompareAndDelete(s, math.MaxInt))
191					m.CompareAndDelete(s, i)
192					expectMissing(t, s, 0)(m.Load(s))
193				}
194				for _, s := range testData {
195					expectMissing(t, s, 0)(m.Load(s))
196				}
197			}(i)
198		}
199		wg.Wait()
200	})
201}
202
203func testAll[K, V comparable](t *testing.T, m *HashTrieMap[K, V], testData map[K]V, yield func(K, V) bool) {
204	for k, v := range testData {
205		expectStored(t, k, v)(m.LoadOrStore(k, v))
206	}
207	visited := make(map[K]int)
208	m.All()(func(key K, got V) bool {
209		want, ok := testData[key]
210		if !ok {
211			t.Errorf("unexpected key %v in map", key)
212			return false
213		}
214		if got != want {
215			t.Errorf("expected key %v to have value %v, got %v", key, want, got)
216			return false
217		}
218		visited[key]++
219		return yield(key, got)
220	})
221	for key, n := range visited {
222		if n > 1 {
223			t.Errorf("visited key %v more than once", key)
224		}
225	}
226}
227
// expectPresent returns a checker for the (value, ok) results of looking up
// key. The checker reports a test error if ok is false. Otherwise it reports
// an error if the value differs from want.
func expectPresent[K, V comparable](t *testing.T, key K, want V) func(got V, ok bool) {
	t.Helper()
	check := func(got V, ok bool) {
		t.Helper()

		switch {
		case !ok:
			t.Errorf("expected key %v to be present in map", key)
		case got != want:
			t.Errorf("expected key %v to have value %v, got %v", key, want, got)
		}
	}
	return check
}
241
// expectMissing returns a checker for the (value, ok) results of looking up
// key. The checker reports a test error if ok is true. Otherwise it reports an
// error if the returned value is not the zero value.
//
// want carries no information. It exists only so that V can be inferred at
// the call site, and it must always be the zero value of V. A non-zero want
// is a bug in the test, so it panics immediately.
func expectMissing[K, V comparable](t *testing.T, key K, want V) func(got V, ok bool) {
	t.Helper()
	var zero V
	if want != zero {
		panic("expectMissing must always have a zero value variable")
	}
	return func(got V, ok bool) {
		t.Helper()

		switch {
		case ok:
			t.Errorf("expected key %v to be missing from map, got value %v", key, got)
		case got != want:
			t.Errorf("expected missing key %v to be paired with the zero value; got %v", key, got)
		}
	}
}
260
// expectLoaded returns a checker for the (value, loaded) results of a
// LoadOrStore call on key. The checker reports an error if loaded is false
// and, independently, another error if the value is not want.
func expectLoaded[K, V comparable](t *testing.T, key K, want V) func(got V, loaded bool) {
	t.Helper()
	checkLoaded := func(got V, loaded bool) {
		t.Helper()

		wasStored := !loaded
		if wasStored {
			t.Errorf("expected key %v to have been loaded, not stored", key)
		}
		// Not an else branch: a wrong value is reported even when the
		// loaded check above has already failed.
		if mismatch := got != want; mismatch {
			t.Errorf("expected key %v to have value %v, got %v", key, want, got)
		}
	}
	return checkLoaded
}
274
// expectStored returns a checker for the (value, loaded) results of a
// LoadOrStore call that should have inserted key. The checker reports an
// error if loaded is true and, independently, another error if the value is
// not want.
func expectStored[K, V comparable](t *testing.T, key K, want V) func(got V, loaded bool) {
	t.Helper()
	checkStored := func(got V, loaded bool) {
		t.Helper()

		if wasLoaded := loaded; wasLoaded {
			t.Errorf("expected inserted key %v to have been stored, not loaded", key)
		}
		// Checked independently of the loaded flag, so both problems can be
		// reported together.
		if mismatch := got != want; mismatch {
			t.Errorf("expected inserted key %v to have value %v, got %v", key, want, got)
		}
	}
	return checkStored
}
288
// expectDeleted returns a checker for the result of CompareAndDelete(key, old).
// The checker reports a test error unless the delete succeeded.
func expectDeleted[K, V comparable](t *testing.T, key K, old V) func(deleted bool) {
	t.Helper()
	return func(deleted bool) {
		t.Helper()

		if deleted {
			return
		}
		t.Errorf("expected key %v with value %v to be in map and deleted", key, old)
	}
}
299
// expectNotDeleted returns a checker for the result of
// CompareAndDelete(key, old). The checker reports a test error if the delete
// succeeded.
func expectNotDeleted[K, V comparable](t *testing.T, key K, old V) func(deleted bool) {
	t.Helper()
	return func(deleted bool) {
		t.Helper()

		if !deleted {
			return
		}
		t.Errorf("expected key %v with value %v to not be in map and thus not deleted", key, old)
	}
}
310
// testDataMap builds a map from each string in data to its index in data.
// If a string appears more than once, the index of its last occurrence wins.
// The result is always non-nil, even for empty input.
func testDataMap(data []string) map[string]int {
	index := make(map[string]int, len(data))
	for pos := 0; pos < len(data); pos++ {
		index[data[pos]] = pos
	}
	return index
}
318
// Fixture keys of three sizes. Entry i of each array is i written in binary
// (e.g. 5 -> "101"), so all keys within an array are distinct.
var (
	testDataSmall [8]string
	testData      [128]string
	testDataLarge [128 << 10]string
)

func init() {
	// fillBinary sets dst[i] to the binary representation of i.
	fillBinary := func(dst []string) {
		for i := range dst {
			dst[i] = fmt.Sprintf("%b", i)
		}
	}
	// Slicing the package-level arrays writes through to their storage.
	fillBinary(testDataSmall[:])
	fillBinary(testData[:])
	fillBinary(testDataLarge[:])
}
336
337func dumpMap[K, V comparable](ht *HashTrieMap[K, V]) {
338	dumpNode(ht, &ht.root.node, 0)
339}
340
341func dumpNode[K, V comparable](ht *HashTrieMap[K, V], n *node[K, V], depth int) {
342	var sb strings.Builder
343	for range depth {
344		fmt.Fprintf(&sb, "\t")
345	}
346	prefix := sb.String()
347	if n.isEntry {
348		e := n.entry()
349		for e != nil {
350			fmt.Printf("%s%p [Entry Key=%v Value=%v Overflow=%p, Hash=%016x]\n", prefix, e, e.key, e.value, e.overflow.Load(), ht.keyHash(unsafe.Pointer(&e.key), ht.seed))
351			e = e.overflow.Load()
352		}
353		return
354	}
355	i := n.indirect()
356	fmt.Printf("%s%p [Indirect Parent=%p Dead=%t Children=[", prefix, i, i.parent, i.dead.Load())
357	for j := range i.children {
358		c := i.children[j].Load()
359		fmt.Printf("%p", c)
360		if j != len(i.children)-1 {
361			fmt.Printf(", ")
362		}
363	}
364	fmt.Printf("]]\n")
365	for j := range i.children {
366		c := i.children[j].Load()
367		if c != nil {
368			dumpNode(ht, c, depth+1)
369		}
370	}
371}
372