1// Copyright 2023 The Go Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5package ssa
6
7import (
8	"fmt"
9)
10
11// ----------------------------------------------------------------------------
12// Sparse Conditional Constant Propagation
13//
14// Described in
15// Mark N. Wegman, F. Kenneth Zadeck: Constant Propagation with Conditional Branches.
16// TOPLAS 1991.
17//
18// This algorithm uses three level lattice for SSA value
19//
20//      Top        undefined
21//     / | \
22// .. 1  2  3 ..   constant
23//     \ | /
24//     Bottom      not constant
25//
26// It starts with optimistically assuming that all SSA values are initially Top
27// and then propagates constant facts only along reachable control flow paths.
// Since some basic blocks may not have been visited yet, the corresponding
// inputs of a phi are treated as Top, and we use meet(phi) to compute its lattice.
30//
31// 	  Top ∩ any = any
32// 	  Bottom ∩ any = Bottom
33// 	  ConstantA ∩ ConstantA = ConstantA
34// 	  ConstantA ∩ ConstantB = Bottom
35//
// Each lattice value is lowered at most twice (Top to Constant, Constant to
// Bottom) due to the lattice depth, resulting in fast convergence of the algorithm.
38// In this way, sccp can discover optimization opportunities that cannot be found
39// by just combining constant folding and constant propagation and dead code
40// elimination separately.
41
// Three level lattice tag that holds compile time knowledge about an SSA value.
// During the analysis a value's tag only ever moves downward
// (top -> constant -> bottom), never upward.
const (
	top      int8 = iota // undefined: nothing known yet (optimistic initial state)
	constant             // known to be a specific constant
	bottom               // proven not to be a constant
)
48
// lattice is the per-value analysis fact: a tag plus, for constants, the
// concrete constant value.
type lattice struct {
	tag int8   // lattice type: top, constant or bottom
	val *Value // the constant value; meaningful only when tag == constant
}
53
// worklist carries the transient state of a single sccp run over a function.
type worklist struct {
	f            *Func               // the target function to be optimized
	edges        []Edge              // CFG edges pending constant-fact propagation
	uses         []*Value            // values to re-visit because an input lattice changed
	visited      map[Edge]bool       // set of already-visited edges
	latticeCells map[*Value]lattice  // current lattice of each value
	defUse       map[*Value][]*Value // def-use chains for values that may become constants
	defBlock     map[*Value][]*Block // blocks controlled by each def
	visitedBlock []bool              // visited blocks, indexed by block ID
}
64
65// sccp stands for sparse conditional constant propagation, it propagates constants
66// through CFG conditionally and applies constant folding, constant replacement and
67// dead code elimination all together.
68func sccp(f *Func) {
69	var t worklist
70	t.f = f
71	t.edges = make([]Edge, 0)
72	t.visited = make(map[Edge]bool)
73	t.edges = append(t.edges, Edge{f.Entry, 0})
74	t.defUse = make(map[*Value][]*Value)
75	t.defBlock = make(map[*Value][]*Block)
76	t.latticeCells = make(map[*Value]lattice)
77	t.visitedBlock = f.Cache.allocBoolSlice(f.NumBlocks())
78	defer f.Cache.freeBoolSlice(t.visitedBlock)
79
80	// build it early since we rely heavily on the def-use chain later
81	t.buildDefUses()
82
83	// pick up either an edge or SSA value from worklist, process it
84	for {
85		if len(t.edges) > 0 {
86			edge := t.edges[0]
87			t.edges = t.edges[1:]
88			if _, exist := t.visited[edge]; !exist {
89				dest := edge.b
90				destVisited := t.visitedBlock[dest.ID]
91
92				// mark edge as visited
93				t.visited[edge] = true
94				t.visitedBlock[dest.ID] = true
95				for _, val := range dest.Values {
96					if val.Op == OpPhi || !destVisited {
97						t.visitValue(val)
98					}
99				}
100				// propagates constants facts through CFG, taking condition test
101				// into account
102				if !destVisited {
103					t.propagate(dest)
104				}
105			}
106			continue
107		}
108		if len(t.uses) > 0 {
109			use := t.uses[0]
110			t.uses = t.uses[1:]
111			t.visitValue(use)
112			continue
113		}
114		break
115	}
116
117	// apply optimizations based on discovered constants
118	constCnt, rewireCnt := t.replaceConst()
119	if f.pass.debug > 0 {
120		if constCnt > 0 || rewireCnt > 0 {
121			fmt.Printf("Phase SCCP for %v : %v constants, %v dce\n", f.Name, constCnt, rewireCnt)
122		}
123	}
124}
125
126func equals(a, b lattice) bool {
127	if a == b {
128		// fast path
129		return true
130	}
131	if a.tag != b.tag {
132		return false
133	}
134	if a.tag == constant {
135		// The same content of const value may be different, we should
136		// compare with auxInt instead
137		v1 := a.val
138		v2 := b.val
139		if v1.Op == v2.Op && v1.AuxInt == v2.AuxInt {
140			return true
141		} else {
142			return false
143		}
144	}
145	return true
146}
147
// possibleConst checks if Value can be folded to const. For those Values that can
// never become constants (e.g. StaticCall), we don't make futile efforts.
//
// NOTE: the two big op lists below must stay in sync with the matching case
// lists in visitValue; an op accepted here but missing there (or vice versa)
// would silently disable folding for that op.
func possibleConst(val *Value) bool {
	if isConst(val) {
		return true
	}
	switch val.Op {
	case OpCopy:
		return true
	case OpPhi:
		return true
	case
		// 1-input operations
		// negate
		OpNeg8, OpNeg16, OpNeg32, OpNeg64, OpNeg32F, OpNeg64F,
		OpCom8, OpCom16, OpCom32, OpCom64,
		// math
		OpFloor, OpCeil, OpTrunc, OpRoundToEven, OpSqrt,
		// conversion
		OpTrunc16to8, OpTrunc32to8, OpTrunc32to16, OpTrunc64to8,
		OpTrunc64to16, OpTrunc64to32, OpCvt32to32F, OpCvt32to64F,
		OpCvt64to32F, OpCvt64to64F, OpCvt32Fto32, OpCvt32Fto64,
		OpCvt64Fto32, OpCvt64Fto64, OpCvt32Fto64F, OpCvt64Fto32F,
		OpCvtBoolToUint8,
		OpZeroExt8to16, OpZeroExt8to32, OpZeroExt8to64, OpZeroExt16to32,
		OpZeroExt16to64, OpZeroExt32to64, OpSignExt8to16, OpSignExt8to32,
		OpSignExt8to64, OpSignExt16to32, OpSignExt16to64, OpSignExt32to64,
		// bit
		OpCtz8, OpCtz16, OpCtz32, OpCtz64,
		// mask
		OpSlicemask,
		// safety check
		OpIsNonNil,
		// not
		OpNot:
		return true
	case
		// 2-input operations
		// add
		OpAdd64, OpAdd32, OpAdd16, OpAdd8,
		OpAdd32F, OpAdd64F,
		// sub
		OpSub64, OpSub32, OpSub16, OpSub8,
		OpSub32F, OpSub64F,
		// mul
		OpMul64, OpMul32, OpMul16, OpMul8,
		OpMul32F, OpMul64F,
		// div
		OpDiv32F, OpDiv64F,
		OpDiv8, OpDiv16, OpDiv32, OpDiv64,
		OpDiv8u, OpDiv16u, OpDiv32u, OpDiv64u,
		OpMod8, OpMod16, OpMod32, OpMod64,
		OpMod8u, OpMod16u, OpMod32u, OpMod64u,
		// compare
		OpEq64, OpEq32, OpEq16, OpEq8,
		OpEq32F, OpEq64F,
		OpLess64, OpLess32, OpLess16, OpLess8,
		OpLess64U, OpLess32U, OpLess16U, OpLess8U,
		OpLess32F, OpLess64F,
		OpLeq64, OpLeq32, OpLeq16, OpLeq8,
		OpLeq64U, OpLeq32U, OpLeq16U, OpLeq8U,
		OpLeq32F, OpLeq64F,
		OpEqB, OpNeqB,
		// shift
		OpLsh64x64, OpRsh64x64, OpRsh64Ux64, OpLsh32x64,
		OpRsh32x64, OpRsh32Ux64, OpLsh16x64, OpRsh16x64,
		OpRsh16Ux64, OpLsh8x64, OpRsh8x64, OpRsh8Ux64,
		// safety check
		OpIsInBounds, OpIsSliceInBounds,
		// bit
		OpAnd8, OpAnd16, OpAnd32, OpAnd64,
		OpOr8, OpOr16, OpOr32, OpOr64,
		OpXor8, OpXor16, OpXor32, OpXor64:
		return true
	default:
		return false
	}
}
224
225func (t *worklist) getLatticeCell(val *Value) lattice {
226	if !possibleConst(val) {
227		// they are always worst
228		return lattice{bottom, nil}
229	}
230	lt, exist := t.latticeCells[val]
231	if !exist {
232		return lattice{top, nil} // optimistically for un-visited value
233	}
234	return lt
235}
236
237func isConst(val *Value) bool {
238	switch val.Op {
239	case OpConst64, OpConst32, OpConst16, OpConst8,
240		OpConstBool, OpConst32F, OpConst64F:
241		return true
242	default:
243		return false
244	}
245}
246
247// buildDefUses builds def-use chain for some values early, because once the
248// lattice of a value is changed, we need to update lattices of use. But we don't
249// need all uses of it, only uses that can become constants would be added into
250// re-visit worklist since no matter how many times they are revisited, uses which
251// can't become constants lattice remains unchanged, i.e. Bottom.
252func (t *worklist) buildDefUses() {
253	for _, block := range t.f.Blocks {
254		for _, val := range block.Values {
255			for _, arg := range val.Args {
256				// find its uses, only uses that can become constants take into account
257				if possibleConst(arg) && possibleConst(val) {
258					if _, exist := t.defUse[arg]; !exist {
259						t.defUse[arg] = make([]*Value, 0, arg.Uses)
260					}
261					t.defUse[arg] = append(t.defUse[arg], val)
262				}
263			}
264		}
265		for _, ctl := range block.ControlValues() {
266			// for control values that can become constants, find their use blocks
267			if possibleConst(ctl) {
268				t.defBlock[ctl] = append(t.defBlock[ctl], block)
269			}
270		}
271	}
272}
273
274// addUses finds all uses of value and appends them into work list for further process
275func (t *worklist) addUses(val *Value) {
276	for _, use := range t.defUse[val] {
277		if val == use {
278			// Phi may refer to itself as uses, ignore them to avoid re-visiting phi
279			// for performance reason
280			continue
281		}
282		t.uses = append(t.uses, use)
283	}
284	for _, block := range t.defBlock[val] {
285		if t.visitedBlock[block.ID] {
286			t.propagate(block)
287		}
288	}
289}
290
291// meet meets all of phi arguments and computes result lattice
292func (t *worklist) meet(val *Value) lattice {
293	optimisticLt := lattice{top, nil}
294	for i := 0; i < len(val.Args); i++ {
295		edge := Edge{val.Block, i}
296		// If incoming edge for phi is not visited, assume top optimistically.
297		// According to rules of meet:
298		// 		Top ∩ any = any
299		// Top participates in meet() but does not affect the result, so here
300		// we will ignore Top and only take other lattices into consideration.
301		if _, exist := t.visited[edge]; exist {
302			lt := t.getLatticeCell(val.Args[i])
303			if lt.tag == constant {
304				if optimisticLt.tag == top {
305					optimisticLt = lt
306				} else {
307					if !equals(optimisticLt, lt) {
308						// ConstantA ∩ ConstantB = Bottom
309						return lattice{bottom, nil}
310					}
311				}
312			} else if lt.tag == bottom {
313				// Bottom ∩ any = Bottom
314				return lattice{bottom, nil}
315			} else {
316				// Top ∩ any = any
317			}
318		} else {
319			// Top ∩ any = any
320		}
321	}
322
323	// ConstantA ∩ ConstantA = ConstantA or Top ∩ any = any
324	return optimisticLt
325}
326
327func computeLattice(f *Func, val *Value, args ...*Value) lattice {
328	// In general, we need to perform constant evaluation based on constant args:
329	//
330	//  res := lattice{constant, nil}
331	// 	switch op {
332	// 	case OpAdd16:
333	//		res.val = newConst(argLt1.val.AuxInt16() + argLt2.val.AuxInt16())
334	// 	case OpAdd32:
335	// 		res.val = newConst(argLt1.val.AuxInt32() + argLt2.val.AuxInt32())
336	//	case OpDiv8:
337	//		if !isDivideByZero(argLt2.val.AuxInt8()) {
338	//			res.val = newConst(argLt1.val.AuxInt8() / argLt2.val.AuxInt8())
339	//		}
340	//  ...
341	// 	}
342	//
343	// However, this would create a huge switch for all opcodes that can be
344	// evaluated during compile time. Moreover, some operations can be evaluated
345	// only if its arguments satisfy additional conditions(e.g. divide by zero).
346	// It's fragile and error-prone. We did a trick by reusing the existing rules
347	// in generic rules for compile-time evaluation. But generic rules rewrite
348	// original value, this behavior is undesired, because the lattice of values
349	// may change multiple times, once it was rewritten, we lose the opportunity
350	// to change it permanently, which can lead to errors. For example, We cannot
351	// change its value immediately after visiting Phi, because some of its input
352	// edges may still not be visited at this moment.
353	constValue := f.newValue(val.Op, val.Type, f.Entry, val.Pos)
354	constValue.AddArgs(args...)
355	matched := rewriteValuegeneric(constValue)
356	if matched {
357		if isConst(constValue) {
358			return lattice{constant, constValue}
359		}
360	}
361	// Either we can not match generic rules for given value or it does not
362	// satisfy additional constraints(e.g. divide by zero), in these cases, clean
363	// up temporary value immediately in case they are not dominated by their args.
364	constValue.reset(OpInvalid)
365	return lattice{bottom, nil}
366}
367
// visitValue (re)computes the lattice of val. The deferred check re-queues all
// uses of val whenever its lattice changed, and verifies that lattices only
// ever move downward (Top -> constant -> Bottom).
//
// NOTE: the two big op lists below must stay in sync with the matching case
// lists in possibleConst.
func (t *worklist) visitValue(val *Value) {
	if !possibleConst(val) {
		// fast fail for always worst Values, i.e. there is no lowering happen
		// on them, their lattices must be initially worse Bottom.
		return
	}

	oldLt := t.getLatticeCell(val)
	defer func() {
		// re-visit all uses of value if its lattice is changed
		newLt := t.getLatticeCell(val)
		if !equals(newLt, oldLt) {
			// lattices may only be lowered; anything else is a bug in this pass
			if int8(oldLt.tag) > int8(newLt.tag) {
				t.f.Fatalf("Must lower lattice\n")
			}
			t.addUses(val)
		}
	}()

	switch val.Op {
	// they are constant values, aren't they?
	case OpConst64, OpConst32, OpConst16, OpConst8,
		OpConstBool, OpConst32F, OpConst64F: //TODO: support ConstNil ConstString etc
		t.latticeCells[val] = lattice{constant, val}
	// lattice value of copy(x) actually means lattice value of (x)
	case OpCopy:
		t.latticeCells[val] = t.getLatticeCell(val.Args[0])
	// phi should be processed specially: meet over all visited incoming edges
	case OpPhi:
		t.latticeCells[val] = t.meet(val)
	// fold 1-input operations:
	case
		// negate
		OpNeg8, OpNeg16, OpNeg32, OpNeg64, OpNeg32F, OpNeg64F,
		OpCom8, OpCom16, OpCom32, OpCom64,
		// math
		OpFloor, OpCeil, OpTrunc, OpRoundToEven, OpSqrt,
		// conversion
		OpTrunc16to8, OpTrunc32to8, OpTrunc32to16, OpTrunc64to8,
		OpTrunc64to16, OpTrunc64to32, OpCvt32to32F, OpCvt32to64F,
		OpCvt64to32F, OpCvt64to64F, OpCvt32Fto32, OpCvt32Fto64,
		OpCvt64Fto32, OpCvt64Fto64, OpCvt32Fto64F, OpCvt64Fto32F,
		OpCvtBoolToUint8,
		OpZeroExt8to16, OpZeroExt8to32, OpZeroExt8to64, OpZeroExt16to32,
		OpZeroExt16to64, OpZeroExt32to64, OpSignExt8to16, OpSignExt8to32,
		OpSignExt8to64, OpSignExt16to32, OpSignExt16to64, OpSignExt32to64,
		// bit
		OpCtz8, OpCtz16, OpCtz32, OpCtz64,
		// mask
		OpSlicemask,
		// safety check
		OpIsNonNil,
		// not
		OpNot:
		lt1 := t.getLatticeCell(val.Args[0])

		if lt1.tag == constant {
			// here we take a shortcut by reusing generic rules to fold constants
			t.latticeCells[val] = computeLattice(t.f, val, lt1.val)
		} else {
			// result is exactly as (un)known as the single input
			t.latticeCells[val] = lattice{lt1.tag, nil}
		}
	// fold 2-input operations
	case
		// add
		OpAdd64, OpAdd32, OpAdd16, OpAdd8,
		OpAdd32F, OpAdd64F,
		// sub
		OpSub64, OpSub32, OpSub16, OpSub8,
		OpSub32F, OpSub64F,
		// mul
		OpMul64, OpMul32, OpMul16, OpMul8,
		OpMul32F, OpMul64F,
		// div
		OpDiv32F, OpDiv64F,
		OpDiv8, OpDiv16, OpDiv32, OpDiv64,
		OpDiv8u, OpDiv16u, OpDiv32u, OpDiv64u, //TODO: support div128u
		// mod
		OpMod8, OpMod16, OpMod32, OpMod64,
		OpMod8u, OpMod16u, OpMod32u, OpMod64u,
		// compare
		OpEq64, OpEq32, OpEq16, OpEq8,
		OpEq32F, OpEq64F,
		OpLess64, OpLess32, OpLess16, OpLess8,
		OpLess64U, OpLess32U, OpLess16U, OpLess8U,
		OpLess32F, OpLess64F,
		OpLeq64, OpLeq32, OpLeq16, OpLeq8,
		OpLeq64U, OpLeq32U, OpLeq16U, OpLeq8U,
		OpLeq32F, OpLeq64F,
		OpEqB, OpNeqB,
		// shift
		OpLsh64x64, OpRsh64x64, OpRsh64Ux64, OpLsh32x64,
		OpRsh32x64, OpRsh32Ux64, OpLsh16x64, OpRsh16x64,
		OpRsh16Ux64, OpLsh8x64, OpRsh8x64, OpRsh8Ux64,
		// safety check
		OpIsInBounds, OpIsSliceInBounds,
		// bit
		OpAnd8, OpAnd16, OpAnd32, OpAnd64,
		OpOr8, OpOr16, OpOr32, OpOr64,
		OpXor8, OpXor16, OpXor32, OpXor64:
		lt1 := t.getLatticeCell(val.Args[0])
		lt2 := t.getLatticeCell(val.Args[1])

		if lt1.tag == constant && lt2.tag == constant {
			// here we take a shortcut by reusing generic rules to fold constants
			t.latticeCells[val] = computeLattice(t.f, val, lt1.val, lt2.val)
		} else {
			// Bottom dominates; otherwise remain optimistic (Top)
			if lt1.tag == bottom || lt2.tag == bottom {
				t.latticeCells[val] = lattice{bottom, nil}
			} else {
				t.latticeCells[val] = lattice{top, nil}
			}
		}
	default:
		// Any other type of value cannot be a constant, they are always worst(Bottom)
	}
}
485
486// propagate propagates constants facts through CFG. If the block has single successor,
487// add the successor anyway. If the block has multiple successors, only add the
488// branch destination corresponding to lattice value of condition value.
489func (t *worklist) propagate(block *Block) {
490	switch block.Kind {
491	case BlockExit, BlockRet, BlockRetJmp, BlockInvalid:
492		// control flow ends, do nothing then
493		break
494	case BlockDefer:
495		// we know nothing about control flow, add all branch destinations
496		t.edges = append(t.edges, block.Succs...)
497	case BlockFirst:
498		fallthrough // always takes the first branch
499	case BlockPlain:
500		t.edges = append(t.edges, block.Succs[0])
501	case BlockIf, BlockJumpTable:
502		cond := block.ControlValues()[0]
503		condLattice := t.getLatticeCell(cond)
504		if condLattice.tag == bottom {
505			// we know nothing about control flow, add all branch destinations
506			t.edges = append(t.edges, block.Succs...)
507		} else if condLattice.tag == constant {
508			// add branchIdx destinations depends on its condition
509			var branchIdx int64
510			if block.Kind == BlockIf {
511				branchIdx = 1 - condLattice.val.AuxInt
512			} else {
513				branchIdx = condLattice.val.AuxInt
514			}
515			t.edges = append(t.edges, block.Succs[branchIdx])
516		} else {
517			// condition value is not visited yet, don't propagate it now
518		}
519	default:
520		t.f.Fatalf("All kind of block should be processed above.")
521	}
522}
523
524// rewireSuccessor rewires corresponding successors according to constant value
525// discovered by previous analysis. As the result, some successors become unreachable
526// and thus can be removed in further deadcode phase
527func rewireSuccessor(block *Block, constVal *Value) bool {
528	switch block.Kind {
529	case BlockIf:
530		block.removeEdge(int(constVal.AuxInt))
531		block.Kind = BlockPlain
532		block.Likely = BranchUnknown
533		block.ResetControls()
534		return true
535	case BlockJumpTable:
536		// Remove everything but the known taken branch.
537		idx := int(constVal.AuxInt)
538		if idx < 0 || idx >= len(block.Succs) {
539			// This can only happen in unreachable code,
540			// as an invariant of jump tables is that their
541			// input index is in range.
542			// See issue 64826.
543			return false
544		}
545		block.swapSuccessorsByIdx(0, idx)
546		for len(block.Succs) > 1 {
547			block.removeEdge(1)
548		}
549		block.Kind = BlockPlain
550		block.Likely = BranchUnknown
551		block.ResetControls()
552		return true
553	default:
554		return false
555	}
556}
557
558// replaceConst will replace non-constant values that have been proven by sccp
559// to be constants.
560func (t *worklist) replaceConst() (int, int) {
561	constCnt, rewireCnt := 0, 0
562	for val, lt := range t.latticeCells {
563		if lt.tag == constant {
564			if !isConst(val) {
565				if t.f.pass.debug > 0 {
566					fmt.Printf("Replace %v with %v\n", val.LongString(), lt.val.LongString())
567				}
568				val.reset(lt.val.Op)
569				val.AuxInt = lt.val.AuxInt
570				constCnt++
571			}
572			// If const value controls this block, rewires successors according to its value
573			ctrlBlock := t.defBlock[val]
574			for _, block := range ctrlBlock {
575				if rewireSuccessor(block, lt.val) {
576					rewireCnt++
577					if t.f.pass.debug > 0 {
578						fmt.Printf("Rewire %v %v successors\n", block.Kind, block)
579					}
580				}
581			}
582		}
583	}
584	return constCnt, rewireCnt
585}
586