1// Copyright 2023 The Go Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5package inlheur
6
7import (
8	"cmd/compile/internal/ir"
9	"cmd/compile/internal/pgoir"
10	"cmd/compile/internal/typecheck"
11	"fmt"
12	"os"
13	"strings"
14)
15
// callSiteAnalyzer pairs a function with a name finder; it is the
// receiver used when scoring call sites (see the scoreCallsRegion
// call in UpdateCallsiteTable).
type callSiteAnalyzer struct {
	// fn is the function under analysis. It may be nil:
	// UpdateCallsiteTable builds an analyzer with a nil function.
	fn *ir.Func
	// nameFinder is embedded so its helper methods can be called
	// directly on the analyzer.
	*nameFinder
}
20
// callSiteTableBuilder carries the state needed while walking a
// region of a function's AST to populate a CallSiteTab.
type callSiteTableBuilder struct {
	// fn is the function whose region is being walked.
	fn *ir.Func
	// nameFinder is embedded; propsForArg uses its constValue,
	// isConcreteConvIface and funcName helpers.
	*nameFinder
	// cstab is the table being filled in; it may start out nil and
	// is allocated on first insertion by addCallSite.
	cstab CallSiteTab
	// ptab maps nodes to panic-path state; it is consulted by
	// determinePanicPathBits.
	ptab map[ir.Node]pstate
	// nstack is the stack of nodes enclosing the node currently
	// being visited; element 0 is fn.
	nstack []ir.Node
	// loopNest counts the enclosing loops that were judged to be
	// real loops by hasTopLevelLoopBodyReturnOrBreak.
	loopNest int
	// isInit records whether fn is a package init function or has a
	// name starting with "init.".
	isInit bool
}
30
31func makeCallSiteAnalyzer(fn *ir.Func) *callSiteAnalyzer {
32	return &callSiteAnalyzer{
33		fn:         fn,
34		nameFinder: newNameFinder(fn),
35	}
36}
37
38func makeCallSiteTableBuilder(fn *ir.Func, cstab CallSiteTab, ptab map[ir.Node]pstate, loopNestingLevel int, nf *nameFinder) *callSiteTableBuilder {
39	isInit := fn.IsPackageInit() || strings.HasPrefix(fn.Sym().Name, "init.")
40	return &callSiteTableBuilder{
41		fn:         fn,
42		cstab:      cstab,
43		ptab:       ptab,
44		isInit:     isInit,
45		loopNest:   loopNestingLevel,
46		nstack:     []ir.Node{fn},
47		nameFinder: nf,
48	}
49}
50
51// computeCallSiteTable builds and returns a table of call sites for
52// the specified region in function fn. A region here corresponds to a
53// specific subtree within the AST for a function. The main intended
54// use cases are for 'region' to be either A) an entire function body,
55// or B) an inlined call expression.
56func computeCallSiteTable(fn *ir.Func, region ir.Nodes, cstab CallSiteTab, ptab map[ir.Node]pstate, loopNestingLevel int, nf *nameFinder) CallSiteTab {
57	cstb := makeCallSiteTableBuilder(fn, cstab, ptab, loopNestingLevel, nf)
58	var doNode func(ir.Node) bool
59	doNode = func(n ir.Node) bool {
60		cstb.nodeVisitPre(n)
61		ir.DoChildren(n, doNode)
62		cstb.nodeVisitPost(n)
63		return false
64	}
65	for _, n := range region {
66		doNode(n)
67	}
68	return cstb.cstab
69}
70
71func (cstb *callSiteTableBuilder) flagsForNode(call *ir.CallExpr) CSPropBits {
72	var r CSPropBits
73
74	if debugTrace&debugTraceCalls != 0 {
75		fmt.Fprintf(os.Stderr, "=-= analyzing call at %s\n",
76			fmtFullPos(call.Pos()))
77	}
78
79	// Set a bit if this call is within a loop.
80	if cstb.loopNest > 0 {
81		r |= CallSiteInLoop
82	}
83
84	// Set a bit if the call is within an init function (either
85	// compiler-generated or user-written).
86	if cstb.isInit {
87		r |= CallSiteInInitFunc
88	}
89
90	// Decide whether to apply the panic path heuristic. Hack: don't
91	// apply this heuristic in the function "main.main" (mostly just
92	// to avoid annoying users).
93	if !isMainMain(cstb.fn) {
94		r = cstb.determinePanicPathBits(call, r)
95	}
96
97	return r
98}
99
100// determinePanicPathBits updates the CallSiteOnPanicPath bit within
101// "r" if we think this call is on an unconditional path to
102// panic/exit. Do this by walking back up the node stack to see if we
103// can find either A) an enclosing panic, or B) a statement node that
104// we've determined leads to a panic/exit.
105func (cstb *callSiteTableBuilder) determinePanicPathBits(call ir.Node, r CSPropBits) CSPropBits {
106	cstb.nstack = append(cstb.nstack, call)
107	defer func() {
108		cstb.nstack = cstb.nstack[:len(cstb.nstack)-1]
109	}()
110
111	for ri := range cstb.nstack[:len(cstb.nstack)-1] {
112		i := len(cstb.nstack) - ri - 1
113		n := cstb.nstack[i]
114		_, isCallExpr := n.(*ir.CallExpr)
115		_, isStmt := n.(ir.Stmt)
116		if isCallExpr {
117			isStmt = false
118		}
119
120		if debugTrace&debugTraceCalls != 0 {
121			ps, inps := cstb.ptab[n]
122			fmt.Fprintf(os.Stderr, "=-= callpar %d op=%s ps=%s inptab=%v stmt=%v\n", i, n.Op().String(), ps.String(), inps, isStmt)
123		}
124
125		if n.Op() == ir.OPANIC {
126			r |= CallSiteOnPanicPath
127			break
128		}
129		if v, ok := cstb.ptab[n]; ok {
130			if v == psCallsPanic {
131				r |= CallSiteOnPanicPath
132				break
133			}
134			if isStmt {
135				break
136			}
137		}
138	}
139	return r
140}
141
142// propsForArg returns property bits for a given call argument expression arg.
143func (cstb *callSiteTableBuilder) propsForArg(arg ir.Node) ActualExprPropBits {
144	if cval := cstb.constValue(arg); cval != nil {
145		return ActualExprConstant
146	}
147	if cstb.isConcreteConvIface(arg) {
148		return ActualExprIsConcreteConvIface
149	}
150	fname := cstb.funcName(arg)
151	if fname != nil {
152		if fn := fname.Func; fn != nil && typecheck.HaveInlineBody(fn) {
153			return ActualExprIsInlinableFunc
154		}
155		return ActualExprIsFunc
156	}
157	return 0
158}
159
160// argPropsForCall returns a slice of argument properties for the
161// expressions being passed to the callee in the specific call
162// expression; these will be stored in the CallSite object for a given
163// call and then consulted when scoring. If no arg has any interesting
164// properties we try to save some space and return a nil slice.
165func (cstb *callSiteTableBuilder) argPropsForCall(ce *ir.CallExpr) []ActualExprPropBits {
166	rv := make([]ActualExprPropBits, len(ce.Args))
167	somethingInteresting := false
168	for idx := range ce.Args {
169		argProp := cstb.propsForArg(ce.Args[idx])
170		somethingInteresting = somethingInteresting || (argProp != 0)
171		rv[idx] = argProp
172	}
173	if !somethingInteresting {
174		return nil
175	}
176	return rv
177}
178
179func (cstb *callSiteTableBuilder) addCallSite(callee *ir.Func, call *ir.CallExpr) {
180	flags := cstb.flagsForNode(call)
181	argProps := cstb.argPropsForCall(call)
182	if debugTrace&debugTraceCalls != 0 {
183		fmt.Fprintf(os.Stderr, "=-= props %+v for call %v\n", argProps, call)
184	}
185	// FIXME: maybe bulk-allocate these?
186	cs := &CallSite{
187		Call:     call,
188		Callee:   callee,
189		Assign:   cstb.containingAssignment(call),
190		ArgProps: argProps,
191		Flags:    flags,
192		ID:       uint(len(cstb.cstab)),
193	}
194	if _, ok := cstb.cstab[call]; ok {
195		fmt.Fprintf(os.Stderr, "*** cstab duplicate entry at: %s\n",
196			fmtFullPos(call.Pos()))
197		fmt.Fprintf(os.Stderr, "*** call: %+v\n", call)
198		panic("bad")
199	}
200	// Set initial score for callsite to the cost computed
201	// by CanInline; this score will be refined later based
202	// on heuristics.
203	cs.Score = int(callee.Inl.Cost)
204
205	if cstb.cstab == nil {
206		cstb.cstab = make(CallSiteTab)
207	}
208	cstb.cstab[call] = cs
209	if debugTrace&debugTraceCalls != 0 {
210		fmt.Fprintf(os.Stderr, "=-= added callsite: caller=%v callee=%v n=%s\n",
211			cstb.fn, callee, fmtFullPos(call.Pos()))
212	}
213}
214
215func (cstb *callSiteTableBuilder) nodeVisitPre(n ir.Node) {
216	switch n.Op() {
217	case ir.ORANGE, ir.OFOR:
218		if !hasTopLevelLoopBodyReturnOrBreak(loopBody(n)) {
219			cstb.loopNest++
220		}
221	case ir.OCALLFUNC:
222		ce := n.(*ir.CallExpr)
223		callee := pgoir.DirectCallee(ce.Fun)
224		if callee != nil && callee.Inl != nil {
225			cstb.addCallSite(callee, ce)
226		}
227	}
228	cstb.nstack = append(cstb.nstack, n)
229}
230
231func (cstb *callSiteTableBuilder) nodeVisitPost(n ir.Node) {
232	cstb.nstack = cstb.nstack[:len(cstb.nstack)-1]
233	switch n.Op() {
234	case ir.ORANGE, ir.OFOR:
235		if !hasTopLevelLoopBodyReturnOrBreak(loopBody(n)) {
236			cstb.loopNest--
237		}
238	}
239}
240
241func loopBody(n ir.Node) ir.Nodes {
242	if forst, ok := n.(*ir.ForStmt); ok {
243		return forst.Body
244	}
245	if rst, ok := n.(*ir.RangeStmt); ok {
246		return rst.Body
247	}
248	return nil
249}
250
251// hasTopLevelLoopBodyReturnOrBreak examines the body of a "for" or
252// "range" loop to try to verify that it is a real loop, as opposed to
253// a construct that is syntactically loopy but doesn't actually iterate
254// multiple times, like:
255//
256//	for {
257//	  blah()
258//	  return 1
259//	}
260//
261// [Remark: the pattern above crops up quite a bit in the source code
262// for the compiler itself, e.g. the auto-generated rewrite code]
263//
264// Note that we don't look for GOTO statements here, so it's possible
265// we'll get the wrong result for a loop with complicated control
266// jumps via gotos.
267func hasTopLevelLoopBodyReturnOrBreak(loopBody ir.Nodes) bool {
268	for _, n := range loopBody {
269		if n.Op() == ir.ORETURN || n.Op() == ir.OBREAK {
270			return true
271		}
272	}
273	return false
274}
275
276// containingAssignment returns the top-level assignment statement
277// for a statement level function call "n". Examples:
278//
279//	x := foo()
280//	x, y := bar(z, baz())
281//	if blah() { ...
282//
283// Here the top-level assignment statement for the foo() call is the
284// statement assigning to "x"; the top-level assignment for "bar()"
285// call is the assignment to x,y. For the baz() and blah() calls,
286// there is no top level assignment statement.
287//
288// The unstated goal here is that we want to use the containing
289// assignment to establish a connection between a given call and the
290// variables to which its results/returns are being assigned.
291//
292// Note that for the "bar" command above, the front end sometimes
293// decomposes this into two assignments, the first one assigning the
294// call to a pair of auto-temps, then the second one assigning the
295// auto-temps to the user-visible vars. This helper will return the
296// second (outer) of these two.
297func (cstb *callSiteTableBuilder) containingAssignment(n ir.Node) ir.Node {
298	parent := cstb.nstack[len(cstb.nstack)-1]
299
300	// assignsOnlyAutoTemps returns TRUE of the specified OAS2FUNC
301	// node assigns only auto-temps.
302	assignsOnlyAutoTemps := func(x ir.Node) bool {
303		alst := x.(*ir.AssignListStmt)
304		oa2init := alst.Init()
305		if len(oa2init) == 0 {
306			return false
307		}
308		for _, v := range oa2init {
309			d := v.(*ir.Decl)
310			if !ir.IsAutoTmp(d.X) {
311				return false
312			}
313		}
314		return true
315	}
316
317	// Simple case: x := foo()
318	if parent.Op() == ir.OAS {
319		return parent
320	}
321
322	// Multi-return case: x, y := bar()
323	if parent.Op() == ir.OAS2FUNC {
324		// Hack city: if the result vars are auto-temps, try looking
325		// for an outer assignment in the tree. The code shape we're
326		// looking for here is:
327		//
328		// OAS1({x,y},OCONVNOP(OAS2FUNC({auto1,auto2},OCALLFUNC(bar))))
329		//
330		if assignsOnlyAutoTemps(parent) {
331			par2 := cstb.nstack[len(cstb.nstack)-2]
332			if par2.Op() == ir.OAS2 {
333				return par2
334			}
335			if par2.Op() == ir.OCONVNOP {
336				par3 := cstb.nstack[len(cstb.nstack)-3]
337				if par3.Op() == ir.OAS2 {
338					return par3
339				}
340			}
341		}
342	}
343
344	return nil
345}
346
347// UpdateCallsiteTable handles updating of callerfn's call site table
348// after an inlined has been carried out, e.g. the call at 'n' as been
349// turned into the inlined call expression 'ic' within function
350// callerfn. The chief thing of interest here is to make sure that any
351// call nodes within 'ic' are added to the call site table for
352// 'callerfn' and scored appropriately.
353func UpdateCallsiteTable(callerfn *ir.Func, n *ir.CallExpr, ic *ir.InlinedCallExpr) {
354	enableDebugTraceIfEnv()
355	defer disableDebugTrace()
356
357	funcInlHeur, ok := fpmap[callerfn]
358	if !ok {
359		// This can happen for compiler-generated wrappers.
360		if debugTrace&debugTraceCalls != 0 {
361			fmt.Fprintf(os.Stderr, "=-= early exit, no entry for caller fn %v\n", callerfn)
362		}
363		return
364	}
365
366	if debugTrace&debugTraceCalls != 0 {
367		fmt.Fprintf(os.Stderr, "=-= UpdateCallsiteTable(caller=%v, cs=%s)\n",
368			callerfn, fmtFullPos(n.Pos()))
369	}
370
371	// Mark the call in question as inlined.
372	oldcs, ok := funcInlHeur.cstab[n]
373	if !ok {
374		// This can happen for compiler-generated wrappers.
375		return
376	}
377	oldcs.aux |= csAuxInlined
378
379	if debugTrace&debugTraceCalls != 0 {
380		fmt.Fprintf(os.Stderr, "=-= marked as inlined: callee=%v %s\n",
381			oldcs.Callee, EncodeCallSiteKey(oldcs))
382	}
383
384	// Walk the inlined call region to collect new callsites.
385	var icp pstate
386	if oldcs.Flags&CallSiteOnPanicPath != 0 {
387		icp = psCallsPanic
388	}
389	var loopNestLevel int
390	if oldcs.Flags&CallSiteInLoop != 0 {
391		loopNestLevel = 1
392	}
393	ptab := map[ir.Node]pstate{ic: icp}
394	nf := newNameFinder(nil)
395	icstab := computeCallSiteTable(callerfn, ic.Body, nil, ptab, loopNestLevel, nf)
396
397	// Record parent callsite. This is primarily for debug output.
398	for _, cs := range icstab {
399		cs.parent = oldcs
400	}
401
402	// Score the calls in the inlined body. Note the setting of
403	// "doCallResults" to false here: at the moment there isn't any
404	// easy way to localize or region-ize the work done by
405	// "rescoreBasedOnCallResultUses", which currently does a walk
406	// over the entire function to look for uses of a given set of
407	// results. Similarly we're passing nil to makeCallSiteAnalyzer,
408	// so as to run name finding without the use of static value &
409	// friends.
410	csa := makeCallSiteAnalyzer(nil)
411	const doCallResults = false
412	csa.scoreCallsRegion(callerfn, ic.Body, icstab, doCallResults, ic)
413}
414