1// Copyright 2023 The Go Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5package inlheur
6
7import (
8	"cmd/compile/internal/ir"
9	"fmt"
10	"os"
11)
12
13// This file contains code to re-score callsites based on how the
14// results of the call were used.  Example:
15//
16//    func foo() {
17//       x, fptr := bar()
18//       switch x {
19//         case 10: fptr = baz()
20//         default: blix()
21//       }
22//       fptr(100)
23//     }
24//
25// The initial scoring pass will assign a score to "bar()" based on
26// various criteria, however once the first pass of scoring is done,
27// we look at the flags on the result from bar, and check to see
28// how those results are used. If bar() always returns the same constant
29// for its first result, and if the variable receiving that result
30// isn't redefined, and if that variable feeds into an if/switch
31// condition, then we will try to adjust the score for "bar" (on the
32// theory that if we inlined, we can constant fold / deadcode).
33
34type resultPropAndCS struct {
35	defcs *CallSite
36	props ResultPropBits
37}
38
39type resultUseAnalyzer struct {
40	resultNameTab map[*ir.Name]resultPropAndCS
41	fn            *ir.Func
42	cstab         CallSiteTab
43	*condLevelTracker
44}
45
46// rescoreBasedOnCallResultUses examines how call results are used,
47// and tries to update the scores of calls based on how their results
48// are used in the function.
49func (csa *callSiteAnalyzer) rescoreBasedOnCallResultUses(fn *ir.Func, resultNameTab map[*ir.Name]resultPropAndCS, cstab CallSiteTab) {
50	enableDebugTraceIfEnv()
51	rua := &resultUseAnalyzer{
52		resultNameTab:    resultNameTab,
53		fn:               fn,
54		cstab:            cstab,
55		condLevelTracker: new(condLevelTracker),
56	}
57	var doNode func(ir.Node) bool
58	doNode = func(n ir.Node) bool {
59		rua.nodeVisitPre(n)
60		ir.DoChildren(n, doNode)
61		rua.nodeVisitPost(n)
62		return false
63	}
64	doNode(fn)
65	disableDebugTrace()
66}
67
68func (csa *callSiteAnalyzer) examineCallResults(cs *CallSite, resultNameTab map[*ir.Name]resultPropAndCS) map[*ir.Name]resultPropAndCS {
69	if debugTrace&debugTraceScoring != 0 {
70		fmt.Fprintf(os.Stderr, "=-= examining call results for %q\n",
71			EncodeCallSiteKey(cs))
72	}
73
74	// Invoke a helper to pick out the specific ir.Name's the results
75	// from this call are assigned into, e.g. "x, y := fooBar()". If
76	// the call is not part of an assignment statement, or if the
77	// variables in question are not newly defined, then we'll receive
78	// an empty list here.
79	//
80	names, autoTemps, props := namesDefined(cs)
81	if len(names) == 0 {
82		return resultNameTab
83	}
84
85	if debugTrace&debugTraceScoring != 0 {
86		fmt.Fprintf(os.Stderr, "=-= %d names defined\n", len(names))
87	}
88
89	// For each returned value, if the value has interesting
90	// properties (ex: always returns the same constant), and the name
91	// in question is never redefined, then make an entry in the
92	// result table for it.
93	const interesting = (ResultIsConcreteTypeConvertedToInterface |
94		ResultAlwaysSameConstant | ResultAlwaysSameInlinableFunc | ResultAlwaysSameFunc)
95	for idx, n := range names {
96		rprop := props.ResultFlags[idx]
97
98		if debugTrace&debugTraceScoring != 0 {
99			fmt.Fprintf(os.Stderr, "=-= props for ret %d %q: %s\n",
100				idx, n.Sym().Name, rprop.String())
101		}
102
103		if rprop&interesting == 0 {
104			continue
105		}
106		if csa.nameFinder.reassigned(n) {
107			continue
108		}
109		if resultNameTab == nil {
110			resultNameTab = make(map[*ir.Name]resultPropAndCS)
111		} else if _, ok := resultNameTab[n]; ok {
112			panic("should never happen")
113		}
114		entry := resultPropAndCS{
115			defcs: cs,
116			props: rprop,
117		}
118		resultNameTab[n] = entry
119		if autoTemps[idx] != nil {
120			resultNameTab[autoTemps[idx]] = entry
121		}
122		if debugTrace&debugTraceScoring != 0 {
123			fmt.Fprintf(os.Stderr, "=-= add resultNameTab table entry n=%v autotemp=%v props=%s\n", n, autoTemps[idx], rprop.String())
124		}
125	}
126	return resultNameTab
127}
128
129// namesDefined returns a list of ir.Name's corresponding to locals
130// that receive the results from the call at site 'cs', plus the
131// properties object for the called function. If a given result
132// isn't cleanly assigned to a newly defined local, the
133// slot for that result in the returned list will be nil. Example:
134//
135//	call                             returned name list
136//
137//	x := foo()                       [ x ]
138//	z, y := bar()                    [ nil, nil ]
139//	_, q := baz()                    [ nil, q ]
140//
141// In the case of a multi-return call, such as "x, y := foo()",
142// the pattern we see from the front end will be a call op
143// assigning to auto-temps, and then an assignment of the auto-temps
144// to the user-level variables. In such cases we return
145// first the user-level variable (in the first func result)
146// and then the auto-temp name in the second result.
147func namesDefined(cs *CallSite) ([]*ir.Name, []*ir.Name, *FuncProps) {
148	// If this call doesn't feed into an assignment (and of course not
149	// all calls do), then we don't have anything to work with here.
150	if cs.Assign == nil {
151		return nil, nil, nil
152	}
153	funcInlHeur, ok := fpmap[cs.Callee]
154	if !ok {
155		// TODO: add an assert/panic here.
156		return nil, nil, nil
157	}
158	if len(funcInlHeur.props.ResultFlags) == 0 {
159		return nil, nil, nil
160	}
161
162	// Single return case.
163	if len(funcInlHeur.props.ResultFlags) == 1 {
164		asgn, ok := cs.Assign.(*ir.AssignStmt)
165		if !ok {
166			return nil, nil, nil
167		}
168		// locate name being assigned
169		aname, ok := asgn.X.(*ir.Name)
170		if !ok {
171			return nil, nil, nil
172		}
173		return []*ir.Name{aname}, []*ir.Name{nil}, funcInlHeur.props
174	}
175
176	// Multi-return case
177	asgn, ok := cs.Assign.(*ir.AssignListStmt)
178	if !ok || !asgn.Def {
179		return nil, nil, nil
180	}
181	userVars := make([]*ir.Name, len(funcInlHeur.props.ResultFlags))
182	autoTemps := make([]*ir.Name, len(funcInlHeur.props.ResultFlags))
183	for idx, x := range asgn.Lhs {
184		if n, ok := x.(*ir.Name); ok {
185			userVars[idx] = n
186			r := asgn.Rhs[idx]
187			if r.Op() == ir.OCONVNOP {
188				r = r.(*ir.ConvExpr).X
189			}
190			if ir.IsAutoTmp(r) {
191				autoTemps[idx] = r.(*ir.Name)
192			}
193			if debugTrace&debugTraceScoring != 0 {
194				fmt.Fprintf(os.Stderr, "=-= multi-ret namedef uv=%v at=%v\n",
195					x, autoTemps[idx])
196			}
197		} else {
198			return nil, nil, nil
199		}
200	}
201	return userVars, autoTemps, funcInlHeur.props
202}
203
204func (rua *resultUseAnalyzer) nodeVisitPost(n ir.Node) {
205	rua.condLevelTracker.post(n)
206}
207
208func (rua *resultUseAnalyzer) nodeVisitPre(n ir.Node) {
209	rua.condLevelTracker.pre(n)
210	switch n.Op() {
211	case ir.OCALLINTER:
212		if debugTrace&debugTraceScoring != 0 {
213			fmt.Fprintf(os.Stderr, "=-= rescore examine iface call %v:\n", n)
214		}
215		rua.callTargetCheckResults(n)
216	case ir.OCALLFUNC:
217		if debugTrace&debugTraceScoring != 0 {
218			fmt.Fprintf(os.Stderr, "=-= rescore examine call %v:\n", n)
219		}
220		rua.callTargetCheckResults(n)
221	case ir.OIF:
222		ifst := n.(*ir.IfStmt)
223		rua.foldCheckResults(ifst.Cond)
224	case ir.OSWITCH:
225		swst := n.(*ir.SwitchStmt)
226		if swst.Tag != nil {
227			rua.foldCheckResults(swst.Tag)
228		}
229
230	}
231}
232
233// callTargetCheckResults examines a given call to see whether the
234// callee expression is potentially an inlinable function returned
235// from a potentially inlinable call. Examples:
236//
237//	Scenario 1: named intermediate
238//
239//	   fn1 := foo()         conc := bar()
240//	   fn1("blah")          conc.MyMethod()
241//
242//	Scenario 2: returned func or concrete object feeds directly to call
243//
244//	   foo()("blah")        bar().MyMethod()
245//
246// In the second case although at the source level the result of the
247// direct call feeds right into the method call or indirect call,
248// we're relying on the front end having inserted an auto-temp to
249// capture the value.
250func (rua *resultUseAnalyzer) callTargetCheckResults(call ir.Node) {
251	ce := call.(*ir.CallExpr)
252	rname := rua.getCallResultName(ce)
253	if rname == nil {
254		return
255	}
256	if debugTrace&debugTraceScoring != 0 {
257		fmt.Fprintf(os.Stderr, "=-= staticvalue returns %v:\n",
258			rname)
259	}
260	if rname.Class != ir.PAUTO {
261		return
262	}
263	switch call.Op() {
264	case ir.OCALLINTER:
265		if debugTrace&debugTraceScoring != 0 {
266			fmt.Fprintf(os.Stderr, "=-= in %s checking %v for cci prop:\n",
267				rua.fn.Sym().Name, rname)
268		}
269		if cs := rua.returnHasProp(rname, ResultIsConcreteTypeConvertedToInterface); cs != nil {
270
271			adj := returnFeedsConcreteToInterfaceCallAdj
272			cs.Score, cs.ScoreMask = adjustScore(adj, cs.Score, cs.ScoreMask)
273		}
274	case ir.OCALLFUNC:
275		if debugTrace&debugTraceScoring != 0 {
276			fmt.Fprintf(os.Stderr, "=-= in %s checking %v for samefunc props:\n",
277				rua.fn.Sym().Name, rname)
278			v, ok := rua.resultNameTab[rname]
279			if !ok {
280				fmt.Fprintf(os.Stderr, "=-= no entry for %v in rt\n", rname)
281			} else {
282				fmt.Fprintf(os.Stderr, "=-= props for %v: %q\n", rname, v.props.String())
283			}
284		}
285		if cs := rua.returnHasProp(rname, ResultAlwaysSameInlinableFunc); cs != nil {
286			adj := returnFeedsInlinableFuncToIndCallAdj
287			cs.Score, cs.ScoreMask = adjustScore(adj, cs.Score, cs.ScoreMask)
288		} else if cs := rua.returnHasProp(rname, ResultAlwaysSameFunc); cs != nil {
289			adj := returnFeedsFuncToIndCallAdj
290			cs.Score, cs.ScoreMask = adjustScore(adj, cs.Score, cs.ScoreMask)
291
292		}
293	}
294}
295
296// foldCheckResults examines the specified if/switch condition 'cond'
297// to see if it refers to locals defined by a (potentially inlinable)
298// function call at call site C, and if so, whether 'cond' contains
299// only combinations of simple references to all of the names in
300// 'names' with selected constants + operators. If these criteria are
301// met, then we adjust the score for call site C to reflect the
302// fact that inlining will enable deadcode and/or constant propagation.
303// Note: for this heuristic to kick in, the names in question have to
304// be all from the same callsite. Examples:
305//
306//	  q, r := baz()	    x, y := foo()
307//	  switch q+r {		a, b, c := bar()
308//		...			    if x && y && a && b && c {
309//	  }					   ...
310//					    }
311//
312// For the call to "baz" above we apply a score adjustment, but not
313// for the calls to "foo" or "bar".
314func (rua *resultUseAnalyzer) foldCheckResults(cond ir.Node) {
315	namesUsed := collectNamesUsed(cond)
316	if len(namesUsed) == 0 {
317		return
318	}
319	var cs *CallSite
320	for _, n := range namesUsed {
321		rpcs, found := rua.resultNameTab[n]
322		if !found {
323			return
324		}
325		if cs != nil && rpcs.defcs != cs {
326			return
327		}
328		cs = rpcs.defcs
329		if rpcs.props&ResultAlwaysSameConstant == 0 {
330			return
331		}
332	}
333	if debugTrace&debugTraceScoring != 0 {
334		nls := func(nl []*ir.Name) string {
335			r := ""
336			for _, n := range nl {
337				r += " " + n.Sym().Name
338			}
339			return r
340		}
341		fmt.Fprintf(os.Stderr, "=-= calling ShouldFoldIfNameConstant on names={%s} cond=%v\n", nls(namesUsed), cond)
342	}
343
344	if !ShouldFoldIfNameConstant(cond, namesUsed) {
345		return
346	}
347	adj := returnFeedsConstToIfAdj
348	cs.Score, cs.ScoreMask = adjustScore(adj, cs.Score, cs.ScoreMask)
349}
350
351func collectNamesUsed(expr ir.Node) []*ir.Name {
352	res := []*ir.Name{}
353	ir.Visit(expr, func(n ir.Node) {
354		if n.Op() != ir.ONAME {
355			return
356		}
357		nn := n.(*ir.Name)
358		if nn.Class != ir.PAUTO {
359			return
360		}
361		res = append(res, nn)
362	})
363	return res
364}
365
366func (rua *resultUseAnalyzer) returnHasProp(name *ir.Name, prop ResultPropBits) *CallSite {
367	v, ok := rua.resultNameTab[name]
368	if !ok {
369		return nil
370	}
371	if v.props&prop == 0 {
372		return nil
373	}
374	return v.defcs
375}
376
377func (rua *resultUseAnalyzer) getCallResultName(ce *ir.CallExpr) *ir.Name {
378	var callTarg ir.Node
379	if sel, ok := ce.Fun.(*ir.SelectorExpr); ok {
380		// method call
381		callTarg = sel.X
382	} else if ctarg, ok := ce.Fun.(*ir.Name); ok {
383		// regular call
384		callTarg = ctarg
385	} else {
386		return nil
387	}
388	r := ir.StaticValue(callTarg)
389	if debugTrace&debugTraceScoring != 0 {
390		fmt.Fprintf(os.Stderr, "=-= staticname on %v returns %v:\n",
391			callTarg, r)
392	}
393	if r.Op() == ir.OCALLFUNC {
394		// This corresponds to the "x := foo()" case; here
395		// ir.StaticValue has brought us all the way back to
396		// the call expression itself. We need to back off to
397		// the name defined by the call; do this by looking up
398		// the callsite.
399		ce := r.(*ir.CallExpr)
400		cs, ok := rua.cstab[ce]
401		if !ok {
402			return nil
403		}
404		names, _, _ := namesDefined(cs)
405		if len(names) == 0 {
406			return nil
407		}
408		return names[0]
409	} else if r.Op() == ir.ONAME {
410		return r.(*ir.Name)
411	}
412	return nil
413}
414