1// Copyright 2023 The Go Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5package inlheur
6
7import (
8	"cmd/compile/internal/ir"
9	"fmt"
10	"os"
11)
12
13// ShouldFoldIfNameConstant analyzes expression tree 'e' to see
14// whether it contains only combinations of simple references to all
15// of the names in 'names' with selected constants + operators. The
16// intent is to identify expression that could be folded away to a
17// constant if the value of 'n' were available. Return value is TRUE
18// if 'e' does look foldable given the value of 'n', and given that
19// 'e' actually makes reference to 'n'. Some examples where the type
20// of "n" is int64, type of "s" is string, and type of "p" is *byte:
21//
22//	Simple?		Expr
23//	yes			n<10
24//	yes			n*n-100
25//	yes			(n < 10 || n > 100) && (n >= 12 || n <= 99 || n != 101)
26//	yes			s == "foo"
27//	yes			p == nil
28//	no			n<foo()
29//	no			n<1 || n>m
30//	no			float32(n)<1.0
31//	no			*p == 1
32//	no			1 + 100
33//	no			1 / n
34//	no			1 + unsafe.Sizeof(n)
35//
36// To avoid complexities (e.g. nan, inf) we stay way from folding and
37// floating point or complex operations (integers, bools, and strings
38// only). We also try to be conservative about avoiding any operation
39// that might result in a panic at runtime, e.g. for "n" with type
40// int64:
41//
42//	1<<(n-9) < 100/(n<<9999)
43//
44// we would return FALSE due to the negative shift count and/or
45// potential divide by zero.
46func ShouldFoldIfNameConstant(n ir.Node, names []*ir.Name) bool {
47	cl := makeExprClassifier(names)
48	var doNode func(ir.Node) bool
49	doNode = func(n ir.Node) bool {
50		ir.DoChildren(n, doNode)
51		cl.Visit(n)
52		return false
53	}
54	doNode(n)
55	if cl.getdisp(n) != exprSimple {
56		return false
57	}
58	for _, v := range cl.names {
59		if !v {
60			return false
61		}
62	}
63	return true
64}
65
// exprClassifier holds intermediate state about nodes within an
// expression tree being analyzed by ShouldFoldIfNameConstant. The
// "names" map has one entry for each name node passed in; an entry
// starts out false and is set to true by Visit once a reference to
// that name has been encountered. The "disposition" map stores the
// result of classifying a given IR node, and is what getdisp and
// dispmeet consult.
type exprClassifier struct {
	names       map[*ir.Name]bool
	disposition map[ir.Node]disp
}
74
// disp is the classification ("disposition") assigned to an IR node
// by the expression classifier.
type disp int

const (
	// exprNoInfo: nothing useful is known about the expression.
	exprNoInfo disp = iota

	// exprLiterals: the expression is made up of literals only.
	exprLiterals

	// exprSimple: the expression is a legal combination of literals
	// and the specified names.
	exprSimple
)

// String returns a short human-readable tag for d, or a string of
// the form "unknown<N>" if d is not one of the declared dispositions.
func (d disp) String() string {
	tags := [...]string{
		exprNoInfo:   "noinfo",
		exprLiterals: "literals",
		exprSimple:   "simple",
	}
	if d >= 0 && int(d) < len(tags) {
		return tags[d]
	}
	return fmt.Sprintf("unknown<%d>", d)
}
100
101func makeExprClassifier(names []*ir.Name) *exprClassifier {
102	m := make(map[*ir.Name]bool, len(names))
103	for _, n := range names {
104		m[n] = false
105	}
106	return &exprClassifier{
107		names:       m,
108		disposition: make(map[ir.Node]disp),
109	}
110}
111
112// Visit sets the classification for 'n' based on the previously
113// calculated classifications for n's children, as part of a bottom-up
114// walk over an expression tree.
115func (ec *exprClassifier) Visit(n ir.Node) {
116
117	ndisp := exprNoInfo
118
119	binparts := func(n ir.Node) (ir.Node, ir.Node) {
120		if lex, ok := n.(*ir.LogicalExpr); ok {
121			return lex.X, lex.Y
122		} else if bex, ok := n.(*ir.BinaryExpr); ok {
123			return bex.X, bex.Y
124		} else {
125			panic("bad")
126		}
127	}
128
129	t := n.Type()
130	if t == nil {
131		if debugTrace&debugTraceExprClassify != 0 {
132			fmt.Fprintf(os.Stderr, "=-= *** untyped op=%s\n",
133				n.Op().String())
134		}
135	} else if t.IsInteger() || t.IsString() || t.IsBoolean() || t.HasNil() {
136		switch n.Op() {
137		// FIXME: maybe add support for OADDSTR?
138		case ir.ONIL:
139			ndisp = exprLiterals
140
141		case ir.OLITERAL:
142			if _, ok := n.(*ir.BasicLit); ok {
143			} else {
144				panic("unexpected")
145			}
146			ndisp = exprLiterals
147
148		case ir.ONAME:
149			nn := n.(*ir.Name)
150			if _, ok := ec.names[nn]; ok {
151				ndisp = exprSimple
152				ec.names[nn] = true
153			} else {
154				sv := ir.StaticValue(n)
155				if sv.Op() == ir.ONAME {
156					nn = sv.(*ir.Name)
157				}
158				if _, ok := ec.names[nn]; ok {
159					ndisp = exprSimple
160					ec.names[nn] = true
161				}
162			}
163
164		case ir.ONOT,
165			ir.OPLUS,
166			ir.ONEG:
167			uex := n.(*ir.UnaryExpr)
168			ndisp = ec.getdisp(uex.X)
169
170		case ir.OEQ,
171			ir.ONE,
172			ir.OLT,
173			ir.OGT,
174			ir.OGE,
175			ir.OLE:
176			// compare ops
177			x, y := binparts(n)
178			ndisp = ec.dispmeet(x, y)
179			if debugTrace&debugTraceExprClassify != 0 {
180				fmt.Fprintf(os.Stderr, "=-= meet(%s,%s) = %s for op=%s\n",
181					ec.getdisp(x), ec.getdisp(y), ec.dispmeet(x, y),
182					n.Op().String())
183			}
184		case ir.OLSH,
185			ir.ORSH,
186			ir.ODIV,
187			ir.OMOD:
188			x, y := binparts(n)
189			if ec.getdisp(y) == exprLiterals {
190				ndisp = ec.dispmeet(x, y)
191			}
192
193		case ir.OADD,
194			ir.OSUB,
195			ir.OOR,
196			ir.OXOR,
197			ir.OMUL,
198			ir.OAND,
199			ir.OANDNOT,
200			ir.OANDAND,
201			ir.OOROR:
202			x, y := binparts(n)
203			if debugTrace&debugTraceExprClassify != 0 {
204				fmt.Fprintf(os.Stderr, "=-= meet(%s,%s) = %s for op=%s\n",
205					ec.getdisp(x), ec.getdisp(y), ec.dispmeet(x, y),
206					n.Op().String())
207			}
208			ndisp = ec.dispmeet(x, y)
209		}
210	}
211
212	if debugTrace&debugTraceExprClassify != 0 {
213		fmt.Fprintf(os.Stderr, "=-= op=%s disp=%v\n", n.Op().String(),
214			ndisp.String())
215	}
216
217	ec.disposition[n] = ndisp
218}
219
220func (ec *exprClassifier) getdisp(x ir.Node) disp {
221	if d, ok := ec.disposition[x]; ok {
222		return d
223	} else {
224		panic("missing node from disp table")
225	}
226}
227
228// dispmeet performs a "meet" operation on the data flow states of
229// node x and y (where the term "meet" is being drawn from traditional
230// lattice-theoretical data flow analysis terminology).
231func (ec *exprClassifier) dispmeet(x, y ir.Node) disp {
232	xd := ec.getdisp(x)
233	if xd == exprNoInfo {
234		return exprNoInfo
235	}
236	yd := ec.getdisp(y)
237	if yd == exprNoInfo {
238		return exprNoInfo
239	}
240	if xd == exprSimple || yd == exprSimple {
241		return exprSimple
242	}
243	if xd != exprLiterals || yd != exprLiterals {
244		panic("unexpected")
245	}
246	return exprLiterals
247}
248