1// Copyright 2023 The Go Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
// Package counter implements the internals of the public counter package.
// In addition to the public API, this package also includes APIs to parse and
// manage the counter files, which are needed by the upload package.
8package counter
9
10import (
11	"fmt"
12	"os"
13	"runtime"
14	"strings"
15	"sync/atomic"
16)
17
var (
	// debugCounter reports whether the GODEBUG environment variable
	// contained the substring "countertrace=1" when the package was
	// initialized. It enables the output of debugPrintf and debugFatalf.
	//
	// Note: not using internal/godebug, so that internal/godebug can use
	// internal/counter.
	debugCounter = strings.Contains(os.Getenv("GODEBUG"), "countertrace=1")
	CrashOnBugs  = false // for testing; if set, exit on fatal log messages
)
24
25// debugPrintf formats a debug message if GODEBUG=countertrace=1.
26func debugPrintf(format string, args ...any) {
27	if debugCounter {
28		if len(format) == 0 || format[len(format)-1] != '\n' {
29			format += "\n"
30		}
31		fmt.Fprintf(os.Stderr, "counter: "+format, args...)
32	}
33}
34
35// debugFatalf logs a fatal error if GODEBUG=countertrace=1.
36func debugFatalf(format string, args ...any) {
37	if debugCounter || CrashOnBugs {
38		if len(format) == 0 || format[len(format)-1] != '\n' {
39			format += "\n"
40		}
41		fmt.Fprintf(os.Stderr, "counter bug: "+format, args...)
42		os.Exit(1)
43	}
44}
45
// A Counter is a single named event counter.
// A Counter is safe for use by multiple goroutines simultaneously.
//
// Counters should typically be created using New
// and stored as global variables, like:
//
//	package mypackage
//	var errorCount = counter.New("mypackage/errors")
//
// (The initialization of errorCount in this example is handled
// entirely by the compiler and linker; this line executes no code
// at program startup.)
//
// Then code can call Add to increment the counter
// each time the corresponding event is observed.
//
// Although it is possible to use New to create
// a Counter each time a particular event needs to be recorded,
// that usage fails to amortize the construction cost over
// multiple calls to Add, so it is more expensive and not recommended.
type Counter struct {
	// name is the counter name given to New and returned by Name.
	name string
	// file is the counter file this counter registers with on Add
	// and from which releaseLock looks up its count pointer.
	file *file

	// NOTE(review): next is not used in this part of the file; presumably
	// it links the counters registered with file - confirm in file.register.
	next atomic.Pointer[Counter]
	// state packs the reader count / exclusive lock, the havePtr flag and
	// the pending "extra" count into one atomically updated word.
	state counterState
	// ptr locates the counter's value in the mapped file. It is assigned
	// by releaseLock and read by Add while holding a reader lock.
	ptr counterPtr
}
74
75func (c *Counter) Name() string {
76	return c.name
77}
78
// counterPtr is the location of a counter's value, as returned by
// file.lookup. The zero value (nil count) means no location is known.
type counterPtr struct {
	// m is kept alive (via runtime.KeepAlive in Counter.add) while count
	// is being updated; presumably count points into m's mapped memory -
	// confirm against mappedFile.
	m *mappedFile
	// count is the atomically updated counter value.
	count *atomic.Uint64
}
83
84type counterState struct {
85	bits atomic.Uint64
86}
87
88func (s *counterState) load() counterStateBits {
89	return counterStateBits(s.bits.Load())
90}
91
92func (s *counterState) update(old *counterStateBits, new counterStateBits) bool {
93	if s.bits.CompareAndSwap(uint64(*old), uint64(new)) {
94		*old = new
95		return true
96	}
97	return false
98}
99
// counterStateBits is one snapshot of a counter's state word.
//
// Layout, from the least significant bit:
//
//	bits 0-29:  number of readers; all ones means the exclusive lock is held
//	bit 30:     havePtr flag
//	bits 31-63: "extra", a saturating 33-bit pending count
type counterStateBits uint64

const (
	stateReaders    counterStateBits = 1<<30 - 1
	stateLocked     counterStateBits = stateReaders
	stateHavePtr    counterStateBits = 1 << 30
	stateExtraShift                  = 31
	stateExtra      counterStateBits = 1<<64 - 1<<stateExtraShift
)

// readers returns the value of the reader-count field.
func (b counterStateBits) readers() int {
	return int(b & stateReaders)
}

// locked reports whether the reader-count field holds the lock sentinel.
func (b counterStateBits) locked() bool {
	return b&stateLocked == stateLocked
}

// havePtr reports whether the havePtr flag is set.
func (b counterStateBits) havePtr() bool {
	return b&stateHavePtr == stateHavePtr
}

// extra returns the value of the extra field.
func (b counterStateBits) extra() uint64 {
	// Shifting discards every bit below the extra field.
	return uint64(b >> stateExtraShift)
}

// incReader returns b with the reader count incremented by one.
func (b counterStateBits) incReader() counterStateBits {
	return b + 1
}

// decReader returns b with the reader count decremented by one.
func (b counterStateBits) decReader() counterStateBits {
	return b - 1
}

// setLocked returns b with the reader-count field set to the lock sentinel.
func (b counterStateBits) setLocked() counterStateBits {
	return b | stateLocked
}

// clearLocked returns b with the reader-count field zeroed.
func (b counterStateBits) clearLocked() counterStateBits {
	return b &^ stateLocked
}

// setHavePtr returns b with the havePtr flag set.
func (b counterStateBits) setHavePtr() counterStateBits {
	return b | stateHavePtr
}

// clearHavePtr returns b with the havePtr flag cleared.
func (b counterStateBits) clearHavePtr() counterStateBits {
	return b &^ stateHavePtr
}

// clearExtra returns b with the extra field zeroed.
func (b counterStateBits) clearExtra() counterStateBits {
	return b & (stateReaders | stateHavePtr)
}

// addExtra returns b with n added to the extra field. The field saturates
// at its maximum value instead of wrapping or spilling into other fields.
func (b counterStateBits) addExtra(n uint64) counterStateBits {
	const maxExtra = uint64(stateExtra) >> stateExtraShift // 0x1ffffffff
	total := b.extra() + n
	// total < n detects uint64 wraparound of the addition itself.
	if total < n || total > maxExtra {
		total = maxExtra
	}
	return b.clearExtra() | counterStateBits(total)<<stateExtraShift
}
132
// New returns a counter with the given name.
// New can be called in global initializers and will be compiled down to
// linker-initialized data. That is, calling New to initialize a global
// has no cost at program startup.
//
// The returned counter is bound to defaultFile and has a zero state;
// it is not registered with the file until the first call to Add with
// a positive value.
func New(name string) *Counter {
	// Note: not calling defaultFile.New in order to keep this
	// function something the compiler can inline and convert
	// into static data initializations, with no init-time footprint.
	return &Counter{name: name, file: &defaultFile}
}
143
144// Inc adds 1 to the counter.
145func (c *Counter) Inc() {
146	c.Add(1)
147}
148
149// Add adds n to the counter. n cannot be negative, as counts cannot decrease.
150func (c *Counter) Add(n int64) {
151	debugPrintf("Add %q += %d", c.name, n)
152
153	if n < 0 {
154		panic("Counter.Add negative")
155	}
156	if n == 0 {
157		return
158	}
159	c.file.register(c)
160
161	state := c.state.load()
162	for ; ; state = c.state.load() {
163		switch {
164		case !state.locked() && state.havePtr():
165			if !c.state.update(&state, state.incReader()) {
166				continue
167			}
168			// Counter unlocked or counter shared; has an initialized count pointer; acquired shared lock.
169			if c.ptr.count == nil {
170				for !c.state.update(&state, state.addExtra(uint64(n))) {
171					// keep trying - we already took the reader lock
172					state = c.state.load()
173				}
174				debugPrintf("Add %q += %d: nil extra=%d\n", c.name, n, state.extra())
175			} else {
176				sum := c.add(uint64(n))
177				debugPrintf("Add %q += %d: count=%d\n", c.name, n, sum)
178			}
179			c.releaseReader(state)
180			return
181
182		case state.locked():
183			if !c.state.update(&state, state.addExtra(uint64(n))) {
184				continue
185			}
186			debugPrintf("Add %q += %d: locked extra=%d\n", c.name, n, state.extra())
187			return
188
189		case !state.havePtr():
190			if !c.state.update(&state, state.addExtra(uint64(n)).setLocked()) {
191				continue
192			}
193			debugPrintf("Add %q += %d: noptr extra=%d\n", c.name, n, state.extra())
194			c.releaseLock(state)
195			return
196		}
197	}
198}
199
// releaseReader drops the shared (reader) lock that Add acquired with
// incReader. state is the caller's most recent view of c.state; whenever a
// compare-and-swap fails, the loop reloads it and tries again.
//
// If the caller is the last reader and havePtr has been cleared, the reader
// lock is upgraded to the exclusive lock and released through releaseLock
// instead of simply being decremented.
func (c *Counter) releaseReader(state counterStateBits) {
	for ; ; state = c.state.load() {
		// If we are the last reader and havePtr was cleared
		// while this batch of readers was using c.ptr,
		// it's our job to update c.ptr by upgrading to a full lock
		// and letting releaseLock do the work.
		// Note: no new reader will attempt to add itself now that havePtr is clear,
		// so we are only racing against possible additions to extra.
		if state.readers() == 1 && !state.havePtr() {
			// setLocked turns our single reader count into the lock sentinel.
			if !c.state.update(&state, state.setLocked()) {
				continue
			}
			debugPrintf("releaseReader %s: last reader, need ptr\n", c.name)
			c.releaseLock(state)
			return
		}

		// Release reader.
		if !c.state.update(&state, state.decReader()) {
			continue
		}
		debugPrintf("releaseReader %s: released (%d readers now)\n", c.name, state.readers())
		return
	}
}
225
// releaseLock releases the exclusive lock on c.state. The callers in this
// file (Add, releaseReader, refresh) invoke it right after a successful
// setLocked update, passing the resulting state.
//
// Before unlocking it makes the counter consistent again: if havePtr is
// clear it sets the flag and reloads c.ptr, and if there is a pending extra
// count and a count location, it moves extra into the mapped count.
// Each step is a compare-and-swap on the whole state word; when one fails
// because another goroutine changed the state, the loop starts over with
// a freshly loaded state.
func (c *Counter) releaseLock(state counterStateBits) {
	for ; ; state = c.state.load() {
		if !state.havePtr() {
			// Set havePtr before updating ptr,
			// to avoid race with the next clear of havePtr.
			if !c.state.update(&state, state.setHavePtr()) {
				continue
			}
			debugPrintf("releaseLock %s: reset havePtr (extra=%d)\n", c.name, state.extra())

			// Optimization: only bother loading a new pointer
			// if we have a value to add to it.
			c.ptr = counterPtr{nil, nil}
			if state.extra() != 0 {
				c.ptr = c.file.lookup(c.name)
				debugPrintf("releaseLock %s: ptr=%v\n", c.name, c.ptr)
			}
		}

		// Flush the pending count, but only if there is somewhere to put it;
		// otherwise extra stays in the state word.
		if extra := state.extra(); extra != 0 && c.ptr.count != nil {
			// Claim extra by zeroing it in the state first, then add the
			// claimed amount to the mapped count.
			if !c.state.update(&state, state.clearExtra()) {
				continue
			}
			sum := c.add(extra)
			debugPrintf("releaseLock %s: flush extra=%d -> count=%d\n", c.name, extra, sum)
		}

		// Took care of refreshing ptr and flushing extra.
		// Now we can release the lock, unless of course
		// another goroutine cleared havePtr or added to extra,
		// in which case we go around again.
		if !c.state.update(&state, state.clearLocked()) {
			continue
		}
		debugPrintf("releaseLock %s: unlocked\n", c.name)
		return
	}
}
264
265// add wraps the atomic.Uint64.Add operation to handle integer overflow.
266func (c *Counter) add(n uint64) uint64 {
267	count := c.ptr.count
268	for {
269		old := count.Load()
270		sum := old + n
271		if sum < old {
272			sum = ^uint64(0)
273		}
274		if count.CompareAndSwap(old, sum) {
275			runtime.KeepAlive(c.ptr.m)
276			return sum
277		}
278	}
279}
280
281func (c *Counter) invalidate() {
282	for {
283		state := c.state.load()
284		if !state.havePtr() {
285			debugPrintf("invalidate %s: no ptr\n", c.name)
286			return
287		}
288		if c.state.update(&state, state.clearHavePtr()) {
289			debugPrintf("invalidate %s: cleared havePtr\n", c.name)
290			return
291		}
292	}
293}
294
295func (c *Counter) refresh() {
296	for {
297		state := c.state.load()
298		if state.havePtr() || state.readers() > 0 || state.extra() == 0 {
299			debugPrintf("refresh %s: havePtr=%v readers=%d extra=%d\n", c.name, state.havePtr(), state.readers(), state.extra())
300			return
301		}
302		if c.state.update(&state, state.setLocked()) {
303			debugPrintf("refresh %s: locked havePtr=%v readers=%d extra=%d\n", c.name, state.havePtr(), state.readers(), state.extra())
304			c.releaseLock(state)
305			return
306		}
307	}
308}
309
310// Read reads the given counter.
311// This is the implementation of x/telemetry/counter/countertest.ReadCounter.
312func Read(c *Counter) (uint64, error) {
313	if c.file.current.Load() == nil {
314		return c.state.load().extra(), nil
315	}
316	pf, err := readFile(c.file)
317	if err != nil {
318		return 0, err
319	}
320	v, ok := pf.Count[DecodeStack(c.Name())]
321	if !ok {
322		return v, fmt.Errorf("not found:%q", DecodeStack(c.Name()))
323	}
324	return v, nil
325}
326
327func readFile(f *file) (*File, error) {
328	if f == nil {
329		debugPrintf("No file")
330		return nil, fmt.Errorf("counter is not initialized - was Open called?")
331	}
332
333	// Note: don't call f.rotate here as this will enqueue a follow-up rotation.
334	f.rotate1()
335
336	if f.err != nil {
337		return nil, fmt.Errorf("failed to rotate mapped file - %v", f.err)
338	}
339	current := f.current.Load()
340	if current == nil {
341		return nil, fmt.Errorf("counter has no mapped file")
342	}
343	name := current.f.Name()
344	data, err := os.ReadFile(name)
345	if err != nil {
346		return nil, fmt.Errorf("failed to read from file: %v", err)
347	}
348	pf, err := Parse(name, data)
349	if err != nil {
350		return nil, fmt.Errorf("failed to parse: %v", err)
351	}
352	return pf, nil
353}
354
355// ReadFile reads the counters and stack counters from the given file.
356// This is the implementation of x/telemetry/counter/countertest.ReadFile.
357func ReadFile(name string) (counters, stackCounters map[string]uint64, _ error) {
358	// TODO: Document the format of the stackCounters names.
359
360	data, err := ReadMapped(name)
361	if err != nil {
362		return nil, nil, fmt.Errorf("failed to read from file: %v", err)
363	}
364	pf, err := Parse(name, data)
365	if err != nil {
366		return nil, nil, fmt.Errorf("failed to parse: %v", err)
367	}
368	counters = make(map[string]uint64)
369	stackCounters = make(map[string]uint64)
370	for k, v := range pf.Count {
371		if IsStackCounter(k) {
372			stackCounters[DecodeStack(k)] = v
373		} else {
374			counters[k] = v
375		}
376	}
377	return counters, stackCounters, nil
378}
379
380// ReadMapped reads the contents of the given file by memory mapping.
381//
382// This avoids file synchronization issues.
383func ReadMapped(name string) ([]byte, error) {
384	f, err := os.OpenFile(name, os.O_RDWR, 0666)
385	if err != nil {
386		return nil, err
387	}
388	defer f.Close()
389	fi, err := f.Stat()
390	if err != nil {
391		return nil, err
392	}
393	mapping, err := memmap(f)
394	if err != nil {
395		return nil, err
396	}
397	data := make([]byte, fi.Size())
398	copy(data, mapping.Data)
399	munmap(mapping)
400	return data, nil
401}
402