1// Copyright 2016 The Go Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5package ssa
6
7import (
8	"math/big"
9	"math/bits"
10)
11
12// So you want to compute x / c for some constant c?
13// Machine division instructions are slow, so we try to
14// compute this division with a multiplication + a few
15// other cheap instructions instead.
16// (We assume here that c != 0, +/- 1, or +/- 2^i.  Those
17// cases are easy to handle in different ways).
18
19// Technique from https://gmplib.org/~tege/divcnst-pldi94.pdf
20
21// First consider unsigned division.
22// Our strategy is to precompute 1/c then do
23//   ⎣x / c⎦ = ⎣x * (1/c)⎦.
24// 1/c is less than 1, so we can't compute it directly in
25// integer arithmetic.  Let's instead compute 2^e/c
26// for a value of e TBD (^ = exponentiation).  Then
27//   ⎣x / c⎦ = ⎣x * (2^e/c) / 2^e⎦.
28// Dividing by 2^e is easy.  2^e/c isn't an integer, unfortunately.
29// So we must approximate it.  Let's call its approximation m.
30// We'll then compute
31//   ⎣x * m / 2^e⎦
// We want this to be equal to ⎣x / c⎦ for 0 <= x <= 2^n-1
33// where n is the word size.
34// Setting x = c gives us c * m >= 2^e.
// We'll choose m = ⎡2^e/c⎤ to satisfy that inequality.
36// What remains is to choose e.
37// Let m = 2^e/c + delta, 0 <= delta < 1
38//   ⎣x * (2^e/c + delta) / 2^e⎦
39//   ⎣x / c + x * delta / 2^e⎦
40// We must have x * delta / 2^e < 1/c so that this
41// additional term never rounds differently than ⎣x / c⎦ does.
42// Rearranging,
43//   2^e > x * delta * c
44// x can be at most 2^n-1 and delta can be at most 1.
45// So it is sufficient to have 2^e >= 2^n*c.
46// So we'll choose e = n + s, with s = ⎡log2(c)⎤.
47//
48// An additional complication arises because m has n+1 bits in it.
49// Hardware restricts us to n bit by n bit multiplies.
50// We divide into 3 cases:
51//
52// Case 1: m is even.
53//   ⎣x / c⎦ = ⎣x * m / 2^(n+s)⎦
54//   ⎣x / c⎦ = ⎣x * (m/2) / 2^(n+s-1)⎦
55//   ⎣x / c⎦ = ⎣x * (m/2) / 2^n / 2^(s-1)⎦
56//   ⎣x / c⎦ = ⎣⎣x * (m/2) / 2^n⎦ / 2^(s-1)⎦
57//   multiply + shift
58//
59// Case 2: c is even.
60//   ⎣x / c⎦ = ⎣(x/2) / (c/2)⎦
61//   ⎣x / c⎦ = ⎣⎣x/2⎦ / (c/2)⎦
62//     This is just the original problem, with x' = ⎣x/2⎦, c' = c/2, n' = n-1.
63//       s' = s-1
64//       m' = ⎡2^(n'+s')/c'⎤
65//          = ⎡2^(n+s-1)/c⎤
66//          = ⎡m/2⎤
67//   ⎣x / c⎦ = ⎣x' * m' / 2^(n'+s')⎦
68//   ⎣x / c⎦ = ⎣⎣x/2⎦ * ⎡m/2⎤ / 2^(n+s-2)⎦
69//   ⎣x / c⎦ = ⎣⎣⎣x/2⎦ * ⎡m/2⎤ / 2^n⎦ / 2^(s-2)⎦
70//   shift + multiply + shift
71//
72// Case 3: everything else
73//   let k = m - 2^n. k fits in n bits.
74//   ⎣x / c⎦ = ⎣x * m / 2^(n+s)⎦
75//   ⎣x / c⎦ = ⎣x * (2^n + k) / 2^(n+s)⎦
76//   ⎣x / c⎦ = ⎣(x + x * k / 2^n) / 2^s⎦
//   ⎣x / c⎦ = ⎣(x + ⎣x * k / 2^n⎦) / 2^s⎦
79//   ⎣x / c⎦ = ⎣⎣(x + ⎣x * k / 2^n⎦) / 2⎦ / 2^(s-1)⎦
80//   multiply + avg + shift
81//
82// These can be implemented in hardware using:
83//  ⎣a * b / 2^n⎦ - aka high n bits of an n-bit by n-bit multiply.
84//  ⎣(a+b) / 2⎦   - aka "average" of two n-bit numbers.
85//                  (Not just a regular add & shift because the intermediate result
86//                   a+b has n+1 bits in it.  Nevertheless, can be done
87//                   in 2 instructions on x86.)
88
// umagicOK reports whether we should strength reduce an n-bit unsigned divide
// by c, where c is a ConstX auxint value. It returns false when the n-bit
// divisor is 0 or a power of 2.
func umagicOK(n uint, c int64) bool {
	// ConstX auxints are stored sign-extended to 64 bits; shifting the value up
	// and back down keeps just the low n bits, i.e. the real unsigned divisor.
	width := 64 - n
	d := uint64(c) << width >> width

	// Clearing the lowest set bit leaves zero exactly when d is 0 or a power
	// of 2, neither of which this technique is used for.
	rest := d & (d - 1)
	return rest != 0
}
98
// umagicOK8, umagicOK16, umagicOK32 and umagicOK64 report whether we should
// strength reduce an unsigned divide of that width by c. Reduction applies
// when c, read as an unsigned bit pattern, is neither 0 nor a power of two.
func umagicOK8(c int8) bool {
	u := uint8(c)
	return u&(u-1) != 0
}

func umagicOK16(c int16) bool {
	u := uint16(c)
	return u&(u-1) != 0
}

func umagicOK32(c int32) bool {
	u := uint32(c)
	return u&(u-1) != 0
}

func umagicOK64(c int64) bool {
	u := uint64(c)
	return u&(u-1) != 0
}
105
// umagicData holds the constants produced by umagic.
type umagicData struct {
	// s is ⎡log2(c)⎤.
	s int64
	// m is ⎡2^(n+s)/c⎤ - 2^n: the magic multiplier with its implicit 2^n bit removed.
	m uint64
}

// umagic returns the constants needed to strength reduce an unsigned n-bit
// divide by the constant uint64(c) (c truncated to its low n bits).
// For every 0 <= x < 2^n the result satisfies
//
//	floor(x / uint64(c)) = x * (m + 2^n) >> (n+s)
func umagic(n uint, c int64) umagicData {
	// Auxint values are sign-extended; truncate to n bits to get the real divisor.
	d := uint64(c) << (64 - n) >> (64 - n)
	div := new(big.Int).SetUint64(d)

	// BitLen equals ⎡log2(d)⎤ when d is not a power of two.
	s := div.BitLen()
	one := big.NewInt(1)

	// ⎡2^(n+s)/d⎤, computed as ⎣(2^(n+s) + d - 1)/d⎦.
	magic := new(big.Int).Lsh(one, n+uint(s))
	magic.Add(magic, div)
	magic.Sub(magic, one)
	magic.Div(magic, div)

	// The multiplier has n+1 bits; bit n must be set and is dropped here,
	// leaving the n-bit value k from "Case 3" above.
	if magic.Bit(int(n)) != 1 {
		panic("n+1st bit isn't set")
	}
	magic.SetBit(magic, int(n), 0)
	return umagicData{s: int64(s), m: magic.Uint64()}
}

// umagic8, umagic16, umagic32 and umagic64 are width-specific entry points to
// umagic; each widens c to int64 and supplies the matching bit width.
func umagic8(c int8) umagicData {
	return umagic(8, int64(c))
}

func umagic16(c int16) umagicData {
	return umagic(16, int64(c))
}

func umagic32(c int32) umagicData {
	return umagic(32, int64(c))
}

func umagic64(c int64) umagicData {
	return umagic(64, c)
}
138
139// For signed division, we use a similar strategy.
140// First, we enforce a positive c.
141//   x / c = -(x / (-c))
142// This will require an additional Neg op for c<0.
143//
144// If x is positive we're in a very similar state
145// to the unsigned case above.  We define:
146//   s = ⎡log2(c)⎤-1
147//   m = ⎡2^(n+s)/c⎤
148// Then
149//   ⎣x / c⎦ = ⎣x * m / 2^(n+s)⎦
150// If x is negative we have
151//   ⎡x / c⎤ = ⎣x * m / 2^(n+s)⎦ + 1
152// (TODO: derivation?)
153//
154// The multiply is a bit odd, as it is a signed n-bit value
155// times an unsigned n-bit value.  For n smaller than the
156// word size, we can extend x and m appropriately and use the
157// signed multiply instruction.  For n == word size,
158// we must use the signed multiply high and correct
159// the result by adding x*2^n.
160//
161// Adding 1 if x<0 is done by subtracting x>>(n-1).
162
// smagicOK reports whether a signed n-bit divide by c should be strength
// reduced. Only positive c that are not powers of 2 qualify; n is not consulted.
func smagicOK(n uint, c int64) bool {
	// Negative c and 0 are rejected outright; powers of 2 are not reduced here.
	return c > 0 && c&(c-1) != 0
}

// smagicOKn reports whether we should strength reduce a signed n-bit divide by c.
func smagicOK8(c int8) bool {
	return smagicOK(8, int64(c))
}

func smagicOK16(c int16) bool {
	return smagicOK(16, int64(c))
}

func smagicOK32(c int32) bool {
	return smagicOK(32, int64(c))
}

func smagicOK64(c int64) bool {
	return smagicOK(64, c)
}
178
// smagicData holds the constants produced by smagic.
type smagicData struct {
	// s is ⎡log2(c)⎤-1.
	s int64
	// m is ⎡2^(n+s)/c⎤, an n-bit value whose top bit is always set.
	m uint64
}

// smagic returns the constants needed to strength reduce a signed n-bit divide
// by the constant c, which must be positive.
// For every -2^(n-1) <= x < 2^(n-1) the result satisfies
//
//	trunc(x / c) = x * m >> (n+s) + (x < 0 ? 1 : 0)
func smagic(n uint, c int64) smagicData {
	divisor := big.NewInt(c)
	s := divisor.BitLen() - 1
	one := big.NewInt(1)

	// ⎡2^(n+s)/c⎤, computed as ⎣(2^(n+s) + c - 1)/c⎦.
	magic := new(big.Int).Lsh(one, n+uint(s))
	magic.Add(magic, divisor)
	magic.Sub(magic, one)
	magic.Div(magic, divisor)

	// m must occupy exactly n bits: bit n clear and bit n-1 set.
	switch {
	case magic.Bit(int(n)) != 0:
		panic("n+1st bit is set")
	case magic.Bit(int(n-1)) == 0:
		panic("nth bit is not set")
	}
	return smagicData{s: int64(s), m: magic.Uint64()}
}

// smagic8, smagic16, smagic32 and smagic64 are width-specific entry points to
// smagic; each widens c to int64 and supplies the matching bit width.
func smagic8(c int8) smagicData {
	return smagic(8, int64(c))
}

func smagic16(c int16) smagicData {
	return smagic(16, int64(c))
}

func smagic32(c int32) smagicData {
	return smagic(32, int64(c))
}

func smagic64(c int64) smagicData {
	return smagic(64, c)
}
211
212// Divisibility x%c == 0 can be checked more efficiently than directly computing
213// the modulus x%c and comparing against 0.
214//
215// The same "Division by invariant integers using multiplication" paper
216// by Granlund and Montgomery referenced above briefly mentions this method
// and it is further elaborated in "Hacker's Delight" by Warren, Section 10-17.
218//
219// The first thing to note is that for odd integers, exact division can be computed
220// by using the modular inverse with respect to the word size 2^n.
221//
222// Given c, compute m such that (c * m) mod 2^n == 1
223// Then if c divides x (x%c ==0), the quotient is given by q = x/c == x*m mod 2^n
224//
225// x can range from 0, c, 2c, 3c, ... ⎣(2^n - 1)/c⎦ * c the maximum multiple
226// Thus, x*m mod 2^n is 0, 1, 2, 3, ... ⎣(2^n - 1)/c⎦
227// i.e. the quotient takes all values from zero up to max = ⎣(2^n - 1)/c⎦
228//
229// If x is not divisible by c, then x*m mod 2^n must take some larger value than max.
230//
231// This gives x*m mod 2^n <= ⎣(2^n - 1)/c⎦ as a test for divisibility
232// involving one multiplication and compare.
233//
234// To extend this to even integers, consider c = d0 * 2^k where d0 is odd.
235// We can test whether x is divisible by both d0 and 2^k.
236// For d0, the test is the same as above.  Let m be such that m*d0 mod 2^n == 1
237// Then x*m mod 2^n <= ⎣(2^n - 1)/d0⎦ is the first test.
238// The test for divisibility by 2^k is a check for k trailing zeroes.
239// Note that since d0 is odd, m is odd and thus x*m will have the same number of
240// trailing zeroes as x.  So the two tests are,
241//
242// x*m mod 2^n <= ⎣(2^n - 1)/d0⎦
243// and x*m ends in k zero bits
244//
245// These can be combined into a single comparison by the following
246// (theorem ZRU in Hacker's Delight) for unsigned integers.
247//
// x <= a and x ends in k zero bits if and only if RotRight(x, k) <= ⎣a/(2^k)⎦
// where RotRight(x, k) is the right rotation of the n-bit value x by k bits.
250//
251// To prove the first direction, x <= a -> ⎣x/(2^k)⎦ <= ⎣a/(2^k)⎦
252// But since x ends in k zeroes all the rotated bits would be zero too.
253// So RotRight(x, k) == ⎣x/(2^k)⎦ <= ⎣a/(2^k)⎦
254//
255// If x does not end in k zero bits, then RotRight(x, k)
256// has some non-zero bits in the k highest bits.
257// ⎣x/(2^k)⎦ has all zeroes in the k highest bits,
258// so RotRight(x, k) > ⎣x/(2^k)⎦
259//
260// Finally, if x > a and has k trailing zero bits, then RotRight(x, k) == ⎣x/(2^k)⎦
261// and ⎣x/(2^k)⎦ must be greater than ⎣a/(2^k)⎦, that is the top n-k bits of x must
262// be greater than the top n-k bits of a because the rest of x bits are zero.
263//
// So the two conditions above can be replaced with the single test
265//
266// RotRight(x*m mod 2^n, k) <= ⎣(2^n - 1)/c⎦
267//
268// Where d0*2^k was replaced by c on the right hand side.
269
// udivisibleOK reports whether we should strength reduce an unsigned n-bit
// divisibility check by c, where c is a ConstX auxint value. The check is
// rewritten only when the n-bit divisor is neither 0 nor a power of 2.
func udivisibleOK(n uint, c int64) bool {
	// ConstX auxints are sign-extended to 64 bits; masking keeps only the low
	// n bits, which is the unsigned divisor the check is against.
	d := uint64(c) & (^uint64(0) >> (64 - n))
	if d == 0 || d&(d-1) == 0 {
		// 0 is not a valid divisor, and powers of 2 are excluded.
		return false
	}
	return true
}

// udivisibleOK8, udivisibleOK16, udivisibleOK32 and udivisibleOK64 are
// width-specific entry points to udivisibleOK.
func udivisibleOK8(c int8) bool {
	return udivisibleOK(8, int64(c))
}

func udivisibleOK16(c int16) bool {
	return udivisibleOK(16, int64(c))
}

func udivisibleOK32(c int32) bool {
	return udivisibleOK(32, int64(c))
}

func udivisibleOK64(c int64) bool {
	return udivisibleOK(64, c)
}
284
// udivisibleData holds the constants used to test an unsigned n-bit value for
// divisibility by c without performing a division.
type udivisibleData struct {
	k   int64  // trailingZeros(c)
	m   uint64 // m * (c>>k) mod 2^n == 1: multiplicative inverse of the odd portion modulo 2^n
	max uint64 // ⎣(2^n - 1)/c⎦: largest rotated product that still indicates divisibility
}

// udivisible computes the constants needed to strength reduce an unsigned n-bit
// divisibility check x%c == 0 by the constant uint64(c) (truncated to n bits).
// With q = x*m mod 2^n, x is divisible by c exactly when
//
//	RotRight(q, k) <= max
//
// where RotRight is an n-bit right rotation. The truncated divisor must be
// nonzero (the computation of max divides by it); udivisibleOK additionally
// excludes powers of 2.
func udivisible(n uint, c int64) udivisibleData {
	// Convert from ConstX auxint values to the real uint64 constant they represent.
	d := uint64(c) << (64 - n) >> (64 - n)

	k := bits.TrailingZeros64(d)
	d0 := d >> uint(k) // the odd portion of the divisor

	mask := ^uint64(0) >> (64 - n)

	// Calculate the multiplicative inverse via Newton's method.
	// Quadratic convergence doubles the number of correct bits per iteration.
	m := d0            // initial guess correct to 3-bits d0*d0 mod 8 == 1
	m = m * (2 - m*d0) // 6-bits
	m = m * (2 - m*d0) // 12-bits
	m = m * (2 - m*d0) // 24-bits
	m = m * (2 - m*d0) // 48-bits
	m = m * (2 - m*d0) // 96-bits >= 64-bits
	m &= mask

	// Largest quotient ⎣x/d⎦ for any n-bit x; named so as not to shadow the
	// builtin max.
	limit := mask / d

	return udivisibleData{
		k:   int64(k),
		m:   m,
		max: limit,
	}
}

// udivisible8, udivisible16, udivisible32 and udivisible64 are width-specific
// entry points to udivisible.
func udivisible8(c int8) udivisibleData   { return udivisible(8, int64(c)) }
func udivisible16(c int16) udivisibleData { return udivisible(16, int64(c)) }
func udivisible32(c int32) udivisibleData { return udivisible(32, int64(c)) }
func udivisible64(c int64) udivisibleData { return udivisible(64, c) }
323
324// For signed integers, a similar method follows.
325//
326// Given c > 1 and odd, compute m such that (c * m) mod 2^n == 1
327// Then if c divides x (x%c ==0), the quotient is given by q = x/c == x*m mod 2^n
328//
329// x can range from ⎡-2^(n-1)/c⎤ * c, ... -c, 0, c, ...  ⎣(2^(n-1) - 1)/c⎦ * c
330// Thus, x*m mod 2^n is ⎡-2^(n-1)/c⎤, ... -2, -1, 0, 1, 2, ... ⎣(2^(n-1) - 1)/c⎦
331//
332// So, x is a multiple of c if and only if:
333// ⎡-2^(n-1)/c⎤ <= x*m mod 2^n <= ⎣(2^(n-1) - 1)/c⎦
334//
335// Since c > 1 and odd, this can be simplified by
336// ⎡-2^(n-1)/c⎤ == ⎡(-2^(n-1) + 1)/c⎤ == -⎣(2^(n-1) - 1)/c⎦
337//
338// -⎣(2^(n-1) - 1)/c⎦ <= x*m mod 2^n <= ⎣(2^(n-1) - 1)/c⎦
339//
340// To extend this to even integers, consider c = d0 * 2^k where d0 is odd.
341// We can test whether x is divisible by both d0 and 2^k.
342//
343// Let m be such that (d0 * m) mod 2^n == 1.
344// Let q = x*m mod 2^n. Then c divides x if:
345//
346// -⎣(2^(n-1) - 1)/d0⎦ <= q <= ⎣(2^(n-1) - 1)/d0⎦ and q ends in at least k 0-bits
347//
348// To transform this to a single comparison, we use the following theorem (ZRS in Hacker's Delight).
349//
350// For a >= 0 the following conditions are equivalent:
351// 1) -a <= x <= a and x ends in at least k 0-bits
352// 2) RotRight(x+a', k) <= ⎣2a'/2^k⎦
353//
354// Where a' = a & -2^k (a with its right k bits set to zero)
355//
// To see that 1 & 2 are equivalent, note that when x ends in at least k 0-bits,
// -a <= x <= a is equivalent to -a' <= x <= a' (a' is the largest multiple of 2^k
// that is <= a, and x is itself a multiple of 2^k).  Adding a' to each side gives
// 0 <= x + a' <= 2a', and x + a' ends in at least k 0-bits if and only if x does, since a' has
// k 0-bits by definition.  We can use theorem ZRU above with x -> x + a' and a -> 2a', giving 1) == 2).
360//
361// Let m be such that (d0 * m) mod 2^n == 1.
362// Let q = x*m mod 2^n.
363// Let a' = ⎣(2^(n-1) - 1)/d0⎦ & -2^k
364//
365// Then the divisibility test is:
366//
367// RotRight(q+a', k) <= ⎣2a'/2^k⎦
368//
369// Note that the calculation is performed using unsigned integers.
370// Since a' can have n-1 bits, 2a' may have n bits and there is no risk of overflow.
371
// sdivisibleOK reports whether we should strength reduce a signed n-bit
// divisibility check by c. Only positive c that are not powers of 2 qualify;
// n is not consulted.
func sdivisibleOK(n uint, c int64) bool {
	switch {
	case c < 0:
		// Doesn't work for negative c.
		return false
	case c == 0, c&(c-1) == 0:
		// 0 is not a valid divisor, and powers of 2 are not handled here.
		return false
	}
	return true
}

// sdivisibleOK8, sdivisibleOK16, sdivisibleOK32 and sdivisibleOK64 are
// width-specific entry points to sdivisibleOK.
func sdivisibleOK8(c int8) bool {
	return sdivisibleOK(8, int64(c))
}

func sdivisibleOK16(c int16) bool {
	return sdivisibleOK(16, int64(c))
}

func sdivisibleOK32(c int32) bool {
	return sdivisibleOK(32, int64(c))
}

func sdivisibleOK64(c int64) bool {
	return sdivisibleOK(64, c)
}
387
// sdivisibleData holds the constants used to test a signed n-bit value for
// divisibility by c without performing a division.
type sdivisibleData struct {
	k   int64  // trailingZeros(c)
	m   uint64 // m * (c>>k) mod 2^n == 1: multiplicative inverse of the odd portion modulo 2^n
	a   uint64 // ⎣(2^(n-1) - 1)/(c>>k)⎦ & -(1<<k): additive constant
	max uint64 // ⎣(2a)/(1<<k)⎦: largest rotated value that still indicates divisibility
}

// sdivisible computes the constants needed to strength reduce a signed n-bit
// divisibility check x%c == 0 by the constant c. With q = (x*m + a) mod 2^n,
// x is divisible by c exactly when
//
//	RotRight(q, k) <= max
//
// where RotRight is an n-bit right rotation. Unlike udivisible, c is used
// as-is rather than truncated to n bits, so it is expected to be positive (as
// sdivisibleOK requires).
func sdivisible(n uint, c int64) sdivisibleData {
	d := uint64(c)
	k := bits.TrailingZeros64(d)
	d0 := d >> uint(k) // the odd portion of the divisor

	mask := ^uint64(0) >> (64 - n)

	// Calculate the multiplicative inverse via Newton's method.
	// Quadratic convergence doubles the number of correct bits per iteration.
	m := d0            // initial guess correct to 3-bits d0*d0 mod 8 == 1
	m = m * (2 - m*d0) // 6-bits
	m = m * (2 - m*d0) // 12-bits
	m = m * (2 - m*d0) // 24-bits
	m = m * (2 - m*d0) // 48-bits
	m = m * (2 - m*d0) // 96-bits >= 64-bits
	m &= mask

	// a' = ⎣(2^(n-1) - 1)/d0⎦ with its low k bits cleared (mask>>1 == 2^(n-1) - 1).
	a := ((mask >> 1) / d0) & -(1 << uint(k))
	// a' has at most n-1 bits, so 2a' fits in n bits without overflow.
	// Named so as not to shadow the builtin max.
	limit := (2 * a) >> uint(k)

	return sdivisibleData{
		k:   int64(k),
		m:   m,
		a:   a,
		max: limit,
	}
}

// sdivisible8, sdivisible16, sdivisible32 and sdivisible64 are width-specific
// entry points to sdivisible.
func sdivisible8(c int8) sdivisibleData   { return sdivisible(8, int64(c)) }
func sdivisible16(c int16) sdivisibleData { return sdivisible(16, int64(c)) }
func sdivisible32(c int32) sdivisibleData { return sdivisible(32, int64(c)) }
func sdivisible64(c int64) sdivisibleData { return sdivisible(64, c) }
427