// Copyright 2018 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

// Package par implements parallel execution helpers.
package par

import (
	"errors"
	"math/rand"
	"sync"
	"sync/atomic"
)

// Work manages a set of work items to be executed in parallel, at most once each.
// The items in the set must all be valid map keys.
type Work[T comparable] struct {
	f       func(T) // function to run for each item
	running int     // total number of runners

	mu      sync.Mutex
	added   map[T]bool // items added to set
	todo    []T        // items yet to be run
	wait    sync.Cond  // wait when todo is empty
	waiting int        // number of runners waiting for todo
}

func (w *Work[T]) init() {
	if w.added == nil {
		w.added = make(map[T]bool)
	}
}

// Add adds item to the work set, if it hasn't already been added.
func (w *Work[T]) Add(item T) {
	w.mu.Lock()
	w.init()
	if !w.added[item] {
		w.added[item] = true
		w.todo = append(w.todo, item)
		if w.waiting > 0 {
			w.wait.Signal()
		}
	}
	w.mu.Unlock()
}

// Do runs f in parallel on items from the work set,
// with at most n invocations of f running at a time.
// It returns when everything added to the work set has been processed.
// At least one item should have been added to the work set
// before calling Do (or else Do returns immediately),
// but it is allowed for f(item) to add new items to the set.
// Do should only be used once on a given Work.
func (w *Work[T]) Do(n int, f func(item T)) {
	if n < 1 {
		panic("par.Work.Do: n < 1")
	}
	if w.running >= 1 {
		panic("par.Work.Do: already called Do")
	}

	w.running = n
	w.f = f
	w.wait.L = &w.mu

	for i := 0; i < n-1; i++ {
		go w.runner()
	}
	w.runner()
}
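
// The sketch below is illustrative only and not part of this package; it
// shows one way Work might be used. The loadDeps helper is assumed for the
// example, standing in for whatever per-item work discovers further items.
// Note that f(item) may call Add, so the set can keep growing while Do runs.
//
//	var w Work[string]
//	w.Add("root")
//	w.Do(4, func(item string) { // at most 4 invocations of f run at a time
//		for _, dep := range loadDeps(item) { // loadDeps is hypothetical
//			w.Add(dep) // duplicates are ignored; new items get scheduled
//		}
//	})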

// runner executes work in w until both nothing is left to do
// and all the runners are waiting for work.
// (Then all the runners return.)
func (w *Work[T]) runner() {
	for {
		// Wait for something to do.
		w.mu.Lock()
		for len(w.todo) == 0 {
			w.waiting++
			if w.waiting == w.running {
				// All done.
				w.wait.Broadcast()
				w.mu.Unlock()
				return
			}
			w.wait.Wait()
			w.waiting--
		}

		// Pick something to do at random,
		// to eliminate pathological contention
		// in case items added at about the same time
		// are most likely to contend.
		i := rand.Intn(len(w.todo))
		item := w.todo[i]
		w.todo[i] = w.todo[len(w.todo)-1]
		w.todo = w.todo[:len(w.todo)-1]
		w.mu.Unlock()

		w.f(item)
	}
}

// ErrCache is like Cache except that it also stores
// an error value alongside the cached value V.
type ErrCache[K comparable, V any] struct {
	Cache[K, errValue[V]]
}

type errValue[V any] struct {
	v   V
	err error
}

// Do calls the function f once per key, the first time Do is called with
// that key, and caches both the value and the error it returns; every call
// to Do with the same key returns that one cached (value, error) pair.
func (c *ErrCache[K, V]) Do(key K, f func() (V, error)) (V, error) {
	v := c.Cache.Do(key, func() errValue[V] {
		v, err := f()
		return errValue[V]{v, err}
	})
	return v.v, v.err
}

// ErrCacheEntryNotFound is returned by ErrCache.Get when no result
// has been cached for the given key.
var ErrCacheEntryNotFound = errors.New("cache entry not found")

// Get returns the cached result associated with key.
// It returns ErrCacheEntryNotFound if there is no such result.
func (c *ErrCache[K, V]) Get(key K) (V, error) {
	v, ok := c.Cache.Get(key)
	if !ok {
		v.err = ErrCacheEntryNotFound
	}
	return v.v, v.err
}
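
// A minimal usage sketch for ErrCache (illustrative only, not part of the
// package); the fetch helper is assumed. Because Do caches the error along
// with the value, a failed computation is not retried for the same key.
//
//	var c ErrCache[string, []byte]
//	data, err := c.Do("https://example.com/a", func() ([]byte, error) {
//		return fetch("https://example.com/a") // fetch is hypothetical
//	})
//	_, _ = data, err
//
//	// Get only consults the cache; it never runs the function.
//	if _, err := c.Get("https://example.com/b"); err == ErrCacheEntryNotFound {
//		// nothing has been cached for this key
//	}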

// Cache runs an action once per key and caches the result.
type Cache[K comparable, V any] struct {
	m sync.Map
}

type cacheEntry[V any] struct {
	done   atomic.Bool
	mu     sync.Mutex
	result V
}

// Do calls the function f if and only if Do is being called for the first time with this key.
// No call to Do with a given key returns until the one call to f returns.
// Do returns the value returned by the one call to f.
func (c *Cache[K, V]) Do(key K, f func() V) V {
	entryIface, ok := c.m.Load(key)
	if !ok {
		entryIface, _ = c.m.LoadOrStore(key, new(cacheEntry[V]))
	}
	e := entryIface.(*cacheEntry[V])
	if !e.done.Load() {
		e.mu.Lock()
		if !e.done.Load() {
			e.result = f()
			e.done.Store(true)
		}
		e.mu.Unlock()
	}
	return e.result
}
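
// An illustrative sketch of Cache.Do (not part of the package): concurrent
// callers passing the same key share a single call to the function, and all
// of them, plus any later callers, get the one cached result. The parseFile
// helper and the File type are assumed for the example.
//
//	var c Cache[string, *File]
//	f := c.Do("go.mod", func() *File {
//		return parseFile("go.mod") // runs at most once for the key "go.mod"
//	})
//	_ = f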

// Get returns the cached result associated with key
// and reports whether there is such a result.
//
// If the result for key is being computed, Get does not wait for the computation to finish.
func (c *Cache[K, V]) Get(key K) (V, bool) {
	entryIface, ok := c.m.Load(key)
	if !ok {
		return *new(V), false
	}
	e := entryIface.(*cacheEntry[V])
	if !e.done.Load() {
		return *new(V), false
	}
	return e.result, true
}
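
// A short sketch of Cache.Get (illustrative only): because Get does not wait
// for an in-flight computation, it reports "not present" for a key whose Do
// call is still running in another goroutine. slowCompute is hypothetical.
//
//	var c Cache[string, int]
//	go c.Do("k", slowCompute)
//	if v, ok := c.Get("k"); ok {
//		_ = v // reached only once the Do above has completed
//	}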