1// Copyright 2023 The Go Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5package testkit
6
7import (
8	"bytes"
9	"encoding/binary"
10	"fmt"
11	"os"
12	"regexp"
13	"strings"
14
15	"internal/trace"
16	"internal/trace/event"
17	"internal/trace/event/go122"
18	"internal/trace/raw"
19	"internal/trace/version"
20	"internal/txtar"
21)
22
23func Main(f func(*Trace)) {
24	// Create an output file.
25	out, err := os.Create(os.Args[1])
26	if err != nil {
27		panic(err.Error())
28	}
29	defer out.Close()
30
31	// Create a new trace.
32	trace := NewTrace()
33
34	// Call the generator.
35	f(trace)
36
37	// Write out the generator's state.
38	if _, err := out.Write(trace.Generate()); err != nil {
39		panic(err.Error())
40	}
41}
42
43// Trace represents an execution trace for testing.
44//
45// It does a little bit of work to ensure that the produced trace is valid,
46// just for convenience. It mainly tracks batches and batch sizes (so they're
47// trivially correct), tracks strings and stacks, and makes sure emitted string
48// and stack batches are valid. That last part can be controlled by a few options.
49//
50// Otherwise, it performs no validation on the trace at all.
51type Trace struct {
52	// Trace data state.
53	ver             version.Version
54	names           map[string]event.Type
55	specs           []event.Spec
56	events          []raw.Event
57	gens            []*Generation
58	validTimestamps bool
59
60	// Expectation state.
61	bad      bool
62	badMatch *regexp.Regexp
63}
64
65// NewTrace creates a new trace.
66func NewTrace() *Trace {
67	ver := version.Go122
68	return &Trace{
69		names:           event.Names(ver.Specs()),
70		specs:           ver.Specs(),
71		validTimestamps: true,
72	}
73}
74
75// ExpectFailure writes down that the trace should be broken. The caller
76// must provide a pattern matching the expected error produced by the parser.
77func (t *Trace) ExpectFailure(pattern string) {
78	t.bad = true
79	t.badMatch = regexp.MustCompile(pattern)
80}
81
82// ExpectSuccess writes down that the trace should successfully parse.
83func (t *Trace) ExpectSuccess() {
84	t.bad = false
85}
86
87// RawEvent emits an event into the trace. name must correspond to one
88// of the names in Specs() result for the version that was passed to
89// this trace.
90func (t *Trace) RawEvent(typ event.Type, data []byte, args ...uint64) {
91	t.events = append(t.events, t.createEvent(typ, data, args...))
92}
93
94// DisableTimestamps makes the timestamps for all events generated after
95// this call zero. Raw events are exempted from this because the caller
96// has to pass their own timestamp into those events anyway.
97func (t *Trace) DisableTimestamps() {
98	t.validTimestamps = false
99}
100
101// Generation creates a new trace generation.
102//
103// This provides more structure than Event to allow for more easily
104// creating complex traces that are mostly or completely correct.
105func (t *Trace) Generation(gen uint64) *Generation {
106	g := &Generation{
107		trace:   t,
108		gen:     gen,
109		strings: make(map[string]uint64),
110		stacks:  make(map[stack]uint64),
111	}
112	t.gens = append(t.gens, g)
113	return g
114}
115
116// Generate creates a test file for the trace.
117func (t *Trace) Generate() []byte {
118	// Trace file contents.
119	var buf bytes.Buffer
120	tw, err := raw.NewTextWriter(&buf, version.Go122)
121	if err != nil {
122		panic(err.Error())
123	}
124
125	// Write raw top-level events.
126	for _, e := range t.events {
127		tw.WriteEvent(e)
128	}
129
130	// Write generations.
131	for _, g := range t.gens {
132		g.writeEventsTo(tw)
133	}
134
135	// Expectation file contents.
136	expect := []byte("SUCCESS\n")
137	if t.bad {
138		expect = []byte(fmt.Sprintf("FAILURE %q\n", t.badMatch))
139	}
140
141	// Create the test file's contents.
142	return txtar.Format(&txtar.Archive{
143		Files: []txtar.File{
144			{Name: "expect", Data: expect},
145			{Name: "trace", Data: buf.Bytes()},
146		},
147	})
148}
149
150func (t *Trace) createEvent(ev event.Type, data []byte, args ...uint64) raw.Event {
151	spec := t.specs[ev]
152	if ev != go122.EvStack {
153		if arity := len(spec.Args); len(args) != arity {
154			panic(fmt.Sprintf("expected %d args for %s, got %d", arity, spec.Name, len(args)))
155		}
156	}
157	return raw.Event{
158		Version: version.Go122,
159		Ev:      ev,
160		Args:    args,
161		Data:    data,
162	}
163}
164
165type stack struct {
166	stk [32]trace.StackFrame
167	len int
168}
169
170var (
171	NoString = ""
172	NoStack  = []trace.StackFrame{}
173)
174
175// Generation represents a single generation in the trace.
176type Generation struct {
177	trace   *Trace
178	gen     uint64
179	batches []*Batch
180	strings map[string]uint64
181	stacks  map[stack]uint64
182
183	// Options applied when Trace.Generate is called.
184	ignoreStringBatchSizeLimit bool
185	ignoreStackBatchSizeLimit  bool
186}
187
188// Batch starts a new event batch in the trace data.
189//
190// This is convenience function for generating correct batches.
191func (g *Generation) Batch(thread trace.ThreadID, time Time) *Batch {
192	if !g.trace.validTimestamps {
193		time = 0
194	}
195	b := &Batch{
196		gen:       g,
197		thread:    thread,
198		timestamp: time,
199	}
200	g.batches = append(g.batches, b)
201	return b
202}
203
204// String registers a string with the trace.
205//
206// This is a convenience function for easily adding correct
207// strings to traces.
208func (g *Generation) String(s string) uint64 {
209	if len(s) == 0 {
210		return 0
211	}
212	if id, ok := g.strings[s]; ok {
213		return id
214	}
215	id := uint64(len(g.strings) + 1)
216	g.strings[s] = id
217	return id
218}
219
220// Stack registers a stack with the trace.
221//
222// This is a convenience function for easily adding correct
223// stacks to traces.
224func (g *Generation) Stack(stk []trace.StackFrame) uint64 {
225	if len(stk) == 0 {
226		return 0
227	}
228	if len(stk) > 32 {
229		panic("stack too big for test")
230	}
231	var stkc stack
232	copy(stkc.stk[:], stk)
233	stkc.len = len(stk)
234	if id, ok := g.stacks[stkc]; ok {
235		return id
236	}
237	id := uint64(len(g.stacks) + 1)
238	g.stacks[stkc] = id
239	return id
240}
241
242// writeEventsTo emits event batches in the generation to tw.
243func (g *Generation) writeEventsTo(tw *raw.TextWriter) {
244	// Write event batches for the generation.
245	for _, b := range g.batches {
246		b.writeEventsTo(tw)
247	}
248
249	// Write frequency.
250	b := g.newStructuralBatch()
251	b.RawEvent(go122.EvFrequency, nil, 15625000)
252	b.writeEventsTo(tw)
253
254	// Write stacks.
255	b = g.newStructuralBatch()
256	b.RawEvent(go122.EvStacks, nil)
257	for stk, id := range g.stacks {
258		stk := stk.stk[:stk.len]
259		args := []uint64{id}
260		for _, f := range stk {
261			args = append(args, f.PC, g.String(f.Func), g.String(f.File), f.Line)
262		}
263		b.RawEvent(go122.EvStack, nil, args...)
264
265		// Flush the batch if necessary.
266		if !g.ignoreStackBatchSizeLimit && b.size > go122.MaxBatchSize/2 {
267			b.writeEventsTo(tw)
268			b = g.newStructuralBatch()
269		}
270	}
271	b.writeEventsTo(tw)
272
273	// Write strings.
274	b = g.newStructuralBatch()
275	b.RawEvent(go122.EvStrings, nil)
276	for s, id := range g.strings {
277		b.RawEvent(go122.EvString, []byte(s), id)
278
279		// Flush the batch if necessary.
280		if !g.ignoreStringBatchSizeLimit && b.size > go122.MaxBatchSize/2 {
281			b.writeEventsTo(tw)
282			b = g.newStructuralBatch()
283		}
284	}
285	b.writeEventsTo(tw)
286}
287
288func (g *Generation) newStructuralBatch() *Batch {
289	return &Batch{gen: g, thread: trace.NoThread}
290}
291
292// Batch represents an event batch.
293type Batch struct {
294	gen       *Generation
295	thread    trace.ThreadID
296	timestamp Time
297	size      uint64
298	events    []raw.Event
299}
300
301// Event emits an event into a batch. name must correspond to one
302// of the names in Specs() result for the version that was passed to
303// this trace. Callers must omit the timestamp delta.
304func (b *Batch) Event(name string, args ...any) {
305	ev, ok := b.gen.trace.names[name]
306	if !ok {
307		panic(fmt.Sprintf("invalid or unknown event %s", name))
308	}
309	var uintArgs []uint64
310	argOff := 0
311	if b.gen.trace.specs[ev].IsTimedEvent {
312		if b.gen.trace.validTimestamps {
313			uintArgs = []uint64{1}
314		} else {
315			uintArgs = []uint64{0}
316		}
317		argOff = 1
318	}
319	spec := b.gen.trace.specs[ev]
320	if arity := len(spec.Args) - argOff; len(args) != arity {
321		panic(fmt.Sprintf("expected %d args for %s, got %d", arity, spec.Name, len(args)))
322	}
323	for i, arg := range args {
324		uintArgs = append(uintArgs, b.uintArgFor(arg, spec.Args[i+argOff]))
325	}
326	b.RawEvent(ev, nil, uintArgs...)
327}
328
329func (b *Batch) uintArgFor(arg any, argSpec string) uint64 {
330	components := strings.SplitN(argSpec, "_", 2)
331	typStr := components[0]
332	if len(components) == 2 {
333		typStr = components[1]
334	}
335	var u uint64
336	switch typStr {
337	case "value":
338		u = arg.(uint64)
339	case "stack":
340		u = b.gen.Stack(arg.([]trace.StackFrame))
341	case "seq":
342		u = uint64(arg.(Seq))
343	case "pstatus":
344		u = uint64(arg.(go122.ProcStatus))
345	case "gstatus":
346		u = uint64(arg.(go122.GoStatus))
347	case "g":
348		u = uint64(arg.(trace.GoID))
349	case "m":
350		u = uint64(arg.(trace.ThreadID))
351	case "p":
352		u = uint64(arg.(trace.ProcID))
353	case "string":
354		u = b.gen.String(arg.(string))
355	case "task":
356		u = uint64(arg.(trace.TaskID))
357	default:
358		panic(fmt.Sprintf("unsupported arg type %q for spec %q", typStr, argSpec))
359	}
360	return u
361}
362
363// RawEvent emits an event into a batch. name must correspond to one
364// of the names in Specs() result for the version that was passed to
365// this trace.
366func (b *Batch) RawEvent(typ event.Type, data []byte, args ...uint64) {
367	ev := b.gen.trace.createEvent(typ, data, args...)
368
369	// Compute the size of the event and add it to the batch.
370	b.size += 1 // One byte for the event header.
371	var buf [binary.MaxVarintLen64]byte
372	for _, arg := range args {
373		b.size += uint64(binary.PutUvarint(buf[:], arg))
374	}
375	if len(data) != 0 {
376		b.size += uint64(binary.PutUvarint(buf[:], uint64(len(data))))
377		b.size += uint64(len(data))
378	}
379
380	// Add the event.
381	b.events = append(b.events, ev)
382}
383
384// writeEventsTo emits events in the batch, including the batch header, to tw.
385func (b *Batch) writeEventsTo(tw *raw.TextWriter) {
386	tw.WriteEvent(raw.Event{
387		Version: version.Go122,
388		Ev:      go122.EvEventBatch,
389		Args:    []uint64{b.gen.gen, uint64(b.thread), uint64(b.timestamp), b.size},
390	})
391	for _, e := range b.events {
392		tw.WriteEvent(e)
393	}
394}
395
type (
	// Seq represents a sequence counter.
	Seq uint64

	// Time represents a low-level trace timestamp. Unlike trace.Time, it
	// does not necessarily correspond to nanoseconds.
	Time uint64
)
402