1// Copyright 2015 The Go Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5package ssa
6
7import (
8	"cmd/compile/internal/types"
9	"cmd/internal/src"
10	"fmt"
11	"sort"
12)
13
// cse does common-subexpression elimination on the Function.
// Values are just relinked, nothing is deleted. A subsequent deadcode
// pass is required to actually remove duplicate expressions.
//
// The pass works in three phases:
//  1. Build a coarse partition of all non-memory values (partitionValues)
//     and refine it until no class contains values with non-equivalent
//     arguments.
//  2. Within each final class, pick replacements using the dominator
//     information returned by f.Sdom().
//  3. Redirect value arguments and block control values to the chosen
//     replacements.
//
// Side effects visible to later passes: the first two arguments of
// commutative values may be swapped in place, statement markers may be
// moved from a replaced argument to its user, and, when f.pass.stats > 0,
// the number of rewritten arguments is logged under "CSE REWRITES".
func cse(f *Func) {
	// Two values are equivalent if they satisfy the following definition:
	// equivalent(v, w):
	//   v.op == w.op
	//   v.type == w.type
	//   v.aux == w.aux
	//   v.auxint == w.auxint
	//   len(v.args) == len(w.args)
	//   v.block == w.block if v.op == OpPhi
	//   equivalent(v.args[i], w.args[i]) for i in 0..len(v.args)-1

	// The algorithm searches for a partition of f's values into
	// equivalence classes using the above definition.
	// It starts with a coarse partition and iteratively refines it
	// until it reaches a fixed point.

	// Make initial coarse partitions by using a subset of the conditions above.
	a := f.Cache.allocValueSlice(f.NumValues())
	defer func() { f.Cache.freeValueSlice(a) }() // inside closure to use final value of a
	a = a[:0] // keep the storage, but start empty so that append fills it
	if f.auxmap == nil {
		f.auxmap = auxmap{}
	}
	for _, b := range f.Blocks {
		for _, v := range b.Values {
			if v.Type.IsMemory() {
				continue // memory values can never cse
			}
			// Give every distinct Aux (including nil) a nonzero ID the first
			// time it is seen; cmpVal uses these IDs to order differing Aux values.
			if f.auxmap[v.Aux] == 0 {
				f.auxmap[v.Aux] = int32(len(f.auxmap)) + 1
			}
			a = append(a, v)
		}
	}
	// partition holds only classes with at least two members; the class
	// slices share storage with a.
	partition := partitionValues(a, f.auxmap)

	// map from value id back to eqclass id
	valueEqClass := f.Cache.allocIDSlice(f.NumValues())
	defer f.Cache.freeIDSlice(valueEqClass)
	for _, b := range f.Blocks {
		for _, v := range b.Values {
			// Use negative equivalence class #s for unique values.
			valueEqClass[v.ID] = -v.ID
		}
	}
	// Number the multi-member classes starting at 1, overwriting the
	// negative placeholder of every value that belongs to one.
	var pNum ID = 1
	for _, e := range partition {
		if f.pass.debug > 1 && len(e) > 500 {
			fmt.Printf("CSE.large partition (%d): ", len(e))
			// Print only a small sample of the class; len(e) > 500 guarantees
			// that e[0..2] exist.
			for j := 0; j < 3; j++ {
				fmt.Printf("%s ", e[j].LongString())
			}
			fmt.Println()
		}

		for _, v := range e {
			valueEqClass[v.ID] = pNum
		}
		if f.pass.debug > 2 && len(e) > 1 {
			fmt.Printf("CSE.partition #%d:", pNum)
			for _, v := range e {
				fmt.Printf(" %s", v.String())
			}
			fmt.Printf("\n")
		}
		pNum++
	}

	// Split equivalence classes at points where they have
	// non-equivalent arguments. Repeat until we can't find any
	// more splits.
	var splitPoints []int
	byArgClass := new(partitionByArgClass) // reusable partitionByArgClass to reduce allocations
	for {
		changed := false

		// partition can grow in the loop. By not using a range loop here,
		// we process new additions as they arrive, avoiding O(n^2) behavior.
		for i := 0; i < len(partition); i++ {
			e := partition[i]

			if opcodeTable[e[0].Op].commutative {
				// Order the first two args before comparison.
				// This swaps the arguments of the values themselves, so that
				// x op y and y op x end up with identical argument class sequences.
				for _, v := range e {
					if valueEqClass[v.Args[0].ID] > valueEqClass[v.Args[1].ID] {
						v.Args[0], v.Args[1] = v.Args[1], v.Args[0]
					}
				}
			}

			// Sort by eq class of arguments.
			byArgClass.a = e
			byArgClass.eqClass = valueEqClass
			sort.Sort(byArgClass)

			// Find split points.
			// After sorting, values with identical argument classes are adjacent,
			// so a split is needed exactly where two neighbors differ.
			splitPoints = append(splitPoints[:0], 0)
			for j := 1; j < len(e); j++ {
				v, w := e[j-1], e[j]
				// Note: commutative args already correctly ordered by byArgClass.
				eqArgs := true
				for k, a := range v.Args {
					b := w.Args[k]
					if valueEqClass[a.ID] != valueEqClass[b.ID] {
						eqArgs = false
						break
					}
				}
				if !eqArgs {
					splitPoints = append(splitPoints, j)
				}
			}
			if len(splitPoints) == 1 {
				continue // no splits, leave equivalence class alone.
			}

			// Move another equivalence class down in place of e.
			partition[i] = partition[len(partition)-1]
			partition = partition[:len(partition)-1]
			i-- // revisit index i, which now holds the class that was moved down

			// Add new equivalence classes for the parts of e we found.
			splitPoints = append(splitPoints, len(e))
			for j := 0; j < len(splitPoints)-1; j++ {
				f := e[splitPoints[j]:splitPoints[j+1]] // note: shadows the *Func parameter within this loop body
				if len(f) == 1 {
					// Don't add singletons.
					// The value becomes unique again, so it gets its negative class number back.
					valueEqClass[f[0].ID] = -f[0].ID
					continue
				}
				for _, v := range f {
					valueEqClass[v.ID] = pNum
				}
				pNum++
				partition = append(partition, f)
			}
			changed = true
		}

		if !changed {
			break
		}
	}

	sdom := f.Sdom()

	// Compute substitutions we would like to do. We substitute v for w
	// if v and w are in the same equivalence class and v dominates w.
	// rewrite is indexed by value ID; a nil entry means "keep the value".
	rewrite := f.Cache.allocValueSlice(f.NumValues())
	defer f.Cache.freeValueSlice(rewrite)
	byDom := new(partitionByDom) // reusable partitionByDom to reduce allocs
	for _, e := range partition {
		byDom.a = e
		byDom.sdom = sdom
		sort.Sort(byDom)
		for i := 0; i < len(e)-1; i++ {
			// e is sorted by domorder, so a maximal dominant element is first in the slice
			v := e[i]
			if v == nil {
				// Already scheduled for replacement by an earlier element.
				continue
			}

			e[i] = nil
			// Replace all elements of e which v dominates
			for j := i + 1; j < len(e); j++ {
				w := e[j]
				if w == nil {
					continue
				}
				if sdom.IsAncestorEq(v.Block, w.Block) {
					rewrite[w.ID] = v
					e[j] = nil
				} else {
					// e is sorted by domorder, so v.Block doesn't dominate any subsequent blocks in e
					break
				}
			}
		}
	}

	// rewrites counts rewritten value arguments only; control value
	// replacements below are not included.
	rewrites := int64(0)

	// Apply substitutions
	for _, b := range f.Blocks {
		for _, v := range b.Values {
			for i, w := range v.Args {
				if x := rewrite[w.ID]; x != nil {
					if w.Pos.IsStmt() == src.PosIsStmt {
						// about to lose a statement marker, w
						// w is an input to v; if they're in the same block
						// and the same line, v is a good-enough new statement boundary.
						if w.Block == v.Block && w.Pos.Line() == v.Pos.Line() {
							v.Pos = v.Pos.WithIsStmt()
							w.Pos = w.Pos.WithNotStmt()
						} // TODO and if this fails?
					}
					v.SetArg(i, x)
					rewrites++
				}
			}
		}
		for i, v := range b.ControlValues() {
			if x := rewrite[v.ID]; x != nil {
				if v.Op == OpNilCheck {
					// nilcheck pass will remove the nil checks and log
					// them appropriately, so don't mess with them here.
					continue
				}
				b.ReplaceControl(i, x)
			}
		}
	}

	if f.pass.stats > 0 {
		f.LogStat("CSE REWRITES", rewrites)
	}
}
234
// An eqclass approximates an equivalence class. During the
// algorithm it may represent the union of several of the
// final equivalence classes, which cse later splits apart.
// The classes produced by partitionValues are sub-slices of one
// shared backing slice rather than independent allocations.
type eqclass []*Value
239
240// partitionValues partitions the values into equivalence classes
241// based on having all the following features match:
242//   - opcode
243//   - type
244//   - auxint
245//   - aux
246//   - nargs
247//   - block # if a phi op
248//   - first two arg's opcodes and auxint
249//   - NOT first two arg's aux; that can break CSE.
250//
251// partitionValues returns a list of equivalence classes, each
252// being a sorted by ID list of *Values. The eqclass slices are
253// backed by the same storage as the input slice.
254// Equivalence classes of size 1 are ignored.
255func partitionValues(a []*Value, auxIDs auxmap) []eqclass {
256	sort.Sort(sortvalues{a, auxIDs})
257
258	var partition []eqclass
259	for len(a) > 0 {
260		v := a[0]
261		j := 1
262		for ; j < len(a); j++ {
263			w := a[j]
264			if cmpVal(v, w, auxIDs) != types.CMPeq {
265				break
266			}
267		}
268		if j > 1 {
269			partition = append(partition, a[:j])
270		}
271		a = a[j:]
272	}
273
274	return partition
275}
276func lt2Cmp(isLt bool) types.Cmp {
277	if isLt {
278		return types.CMPlt
279	}
280	return types.CMPgt
281}
282
// auxmap assigns an integer ID to each distinct Aux value. cmpVal uses
// the IDs to put values whose Aux fields differ into a definite order.
type auxmap map[Aux]int32
284
285func cmpVal(v, w *Value, auxIDs auxmap) types.Cmp {
286	// Try to order these comparison by cost (cheaper first)
287	if v.Op != w.Op {
288		return lt2Cmp(v.Op < w.Op)
289	}
290	if v.AuxInt != w.AuxInt {
291		return lt2Cmp(v.AuxInt < w.AuxInt)
292	}
293	if len(v.Args) != len(w.Args) {
294		return lt2Cmp(len(v.Args) < len(w.Args))
295	}
296	if v.Op == OpPhi && v.Block != w.Block {
297		return lt2Cmp(v.Block.ID < w.Block.ID)
298	}
299	if v.Type.IsMemory() {
300		// We will never be able to CSE two values
301		// that generate memory.
302		return lt2Cmp(v.ID < w.ID)
303	}
304	// OpSelect is a pseudo-op. We need to be more aggressive
305	// regarding CSE to keep multiple OpSelect's of the same
306	// argument from existing.
307	if v.Op != OpSelect0 && v.Op != OpSelect1 && v.Op != OpSelectN {
308		if tc := v.Type.Compare(w.Type); tc != types.CMPeq {
309			return tc
310		}
311	}
312
313	if v.Aux != w.Aux {
314		if v.Aux == nil {
315			return types.CMPlt
316		}
317		if w.Aux == nil {
318			return types.CMPgt
319		}
320		return lt2Cmp(auxIDs[v.Aux] < auxIDs[w.Aux])
321	}
322
323	return types.CMPeq
324}
325
326// Sort values to make the initial partition.
327type sortvalues struct {
328	a      []*Value // array of values
329	auxIDs auxmap   // aux -> aux ID map
330}
331
332func (sv sortvalues) Len() int      { return len(sv.a) }
333func (sv sortvalues) Swap(i, j int) { sv.a[i], sv.a[j] = sv.a[j], sv.a[i] }
334func (sv sortvalues) Less(i, j int) bool {
335	v := sv.a[i]
336	w := sv.a[j]
337	if cmp := cmpVal(v, w, sv.auxIDs); cmp != types.CMPeq {
338		return cmp == types.CMPlt
339	}
340
341	// Sort by value ID last to keep the sort result deterministic.
342	return v.ID < w.ID
343}
344
345type partitionByDom struct {
346	a    []*Value // array of values
347	sdom SparseTree
348}
349
350func (sv partitionByDom) Len() int      { return len(sv.a) }
351func (sv partitionByDom) Swap(i, j int) { sv.a[i], sv.a[j] = sv.a[j], sv.a[i] }
352func (sv partitionByDom) Less(i, j int) bool {
353	v := sv.a[i]
354	w := sv.a[j]
355	return sv.sdom.domorder(v.Block) < sv.sdom.domorder(w.Block)
356}
357
358type partitionByArgClass struct {
359	a       []*Value // array of values
360	eqClass []ID     // equivalence class IDs of values
361}
362
363func (sv partitionByArgClass) Len() int      { return len(sv.a) }
364func (sv partitionByArgClass) Swap(i, j int) { sv.a[i], sv.a[j] = sv.a[j], sv.a[i] }
365func (sv partitionByArgClass) Less(i, j int) bool {
366	v := sv.a[i]
367	w := sv.a[j]
368	for i, a := range v.Args {
369		b := w.Args[i]
370		if sv.eqClass[a.ID] < sv.eqClass[b.ID] {
371			return true
372		}
373		if sv.eqClass[a.ID] > sv.eqClass[b.ID] {
374			return false
375		}
376	}
377	return false
378}
379