1// Copyright 2020 The Go Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5package ssa
6
7import (
8	"math"
9	"math/rand"
10	"testing"
11)
12
// Operands shared by the tests and benchmarks in this file.
//
// The x* values are two below the maximum of their type and the y* values are
// one above the minimum, so adding or subtracting the small amounts used below
// wraps around and flips the sign of the result. one64/one32 are the
// multipliers used in the multiply-add/multiply-sub tests; v*/v*_n are a small
// positive/negative addend; uv32 is an unsigned operand for the shift tests;
// uz is a variable shift count.
//
// NOTE(review): these are variables rather than constants presumably so that
// the compiler cannot fold the tested expressions at compile time — confirm.
var (
	x64   int64  = math.MaxInt64 - 2
	x64b  int64  = math.MaxInt64 - 2
	x64c  int64  = math.MaxInt64 - 2
	y64   int64  = math.MinInt64 + 1
	x32   int32  = math.MaxInt32 - 2
	x32b  int32  = math.MaxInt32 - 2
	x32c  int32  = math.MaxInt32 - 2
	y32   int32  = math.MinInt32 + 1
	one64 int64  = 1
	one32 int32  = 1
	v64   int64  = 11 // ensure it's not 2**n +/- 1
	v64_n int64  = -11
	v32   int32  = 11
	v32_n int32  = -11
	uv32  uint32 = 19
	uz    uint8  = 1 // for lowering to SLL/SRL/SRA
)
31
32var crTests = []struct {
33	name string
34	tf   func(t *testing.T)
35}{
36	{"AddConst64", testAddConst64},
37	{"AddConst32", testAddConst32},
38	{"AddVar64", testAddVar64},
39	{"AddVar64Cset", testAddVar64Cset},
40	{"AddVar32", testAddVar32},
41	{"MAddVar64", testMAddVar64},
42	{"MAddVar32", testMAddVar32},
43	{"MSubVar64", testMSubVar64},
44	{"MSubVar32", testMSubVar32},
45	{"AddShift32", testAddShift32},
46	{"SubShift32", testSubShift32},
47}
48
49var crBenches = []struct {
50	name string
51	bf   func(b *testing.B)
52}{
53	{"SoloJump", benchSoloJump},
54	{"CombJump", benchCombJump},
55}
56
57// Test int32/int64's add/sub/madd/msub operations with boundary values to
58// ensure the optimization to 'comparing to zero' expressions of if-statements
59// yield expected results.
60// 32 rewriting rules are covered. At least two scenarios for "Canonicalize
61// the order of arguments to comparisons", which helps with CSE, are covered.
62// The tedious if-else structures are necessary to ensure all concerned rules
63// and machine code sequences are covered.
64// It's for arm64 initially, please see https://github.com/golang/go/issues/38740
65func TestCondRewrite(t *testing.T) {
66	for _, test := range crTests {
67		t.Run(test.name, test.tf)
68	}
69}
70
71// Profile the aforementioned optimization from two angles:
72//
73//	SoloJump: generated branching code has one 'jump', for '<' and '>='
74//	CombJump: generated branching code has two consecutive 'jump', for '<=' and '>'
75//
76// We expect that 'CombJump' is generally on par with the non-optimized code, and
77// 'SoloJump' demonstrates some improvement.
78// It's for arm64 initially, please see https://github.com/golang/go/issues/38740
79func BenchmarkCondRewrite(b *testing.B) {
80	for _, bench := range crBenches {
81		b.Run(bench.name, bench.bf)
82	}
83}
84
// testAddConst64 checks 64-bit "variable +/- constant" expressions compared
// against zero. x64 is two below math.MaxInt64 and y64 is one above
// math.MinInt64, so every expression below wraps around: x64+c comes out
// negative and y64-c comes out positive.
func testAddConst64(t *testing.T) {
	// Conditions that must hold. The empty then-branch is deliberate; the
	// failure is reported from the else branch.
	if x64+11 < 0 {
	} else {
		t.Errorf("'%#x + 11 < 0' failed", x64)
	}

	if x64+13 <= 0 {
	} else {
		t.Errorf("'%#x + 13 <= 0' failed", x64)
	}

	if y64-11 > 0 {
	} else {
		t.Errorf("'%#x - 11 > 0' failed", y64)
	}

	if y64-13 >= 0 {
	} else {
		t.Errorf("'%#x - 13 >= 0' failed", y64)
	}

	// Conditions that must not hold; taking the branch is the failure.
	if x64+19 > 0 {
		t.Errorf("'%#x + 19 > 0' failed", x64)
	}

	if x64+23 >= 0 {
		t.Errorf("'%#x + 23 >= 0' failed", x64)
	}

	if y64-19 < 0 {
		t.Errorf("'%#x - 19 < 0' failed", y64)
	}

	if y64-23 <= 0 {
		t.Errorf("'%#x - 23 <= 0' failed", y64)
	}
}
123
// testAddConst32 is the 32-bit counterpart of the "variable +/- constant"
// checks. x32 is two below math.MaxInt32 and y32 is one above math.MinInt32,
// so x32+c wraps to a negative value and y32-c wraps to a positive value.
func testAddConst32(t *testing.T) {
	// Conditions that must hold. The empty then-branch is deliberate; the
	// failure is reported from the else branch.
	if x32+11 < 0 {
	} else {
		t.Errorf("'%#x + 11 < 0' failed", x32)
	}

	if x32+13 <= 0 {
	} else {
		t.Errorf("'%#x + 13 <= 0' failed", x32)
	}

	if y32-11 > 0 {
	} else {
		t.Errorf("'%#x - 11 > 0' failed", y32)
	}

	if y32-13 >= 0 {
	} else {
		t.Errorf("'%#x - 13 >= 0' failed", y32)
	}

	// Conditions that must not hold; taking the branch is the failure.
	if x32+19 > 0 {
		t.Errorf("'%#x + 19 > 0' failed", x32)
	}

	if x32+23 >= 0 {
		t.Errorf("'%#x + 23 >= 0' failed", x32)
	}

	if y32-19 < 0 {
		t.Errorf("'%#x - 19 < 0' failed", y32)
	}

	if y32-23 <= 0 {
		t.Errorf("'%#x - 23 <= 0' failed", y32)
	}
}
162
// testAddVar64 checks 64-bit "variable + variable" expressions compared
// against zero. x64+v64 (near-max plus 11) wraps to a negative value and
// y64+v64_n (near-min plus -11) wraps to a positive value.
func testAddVar64(t *testing.T) {
	// Conditions that must hold. The empty then-branch is deliberate; the
	// failure is reported from the else branch.
	if x64+v64 < 0 {
	} else {
		t.Errorf("'%#x + %#x < 0' failed", x64, v64)
	}

	if x64+v64 <= 0 {
	} else {
		t.Errorf("'%#x + %#x <= 0' failed", x64, v64)
	}

	if y64+v64_n > 0 {
	} else {
		t.Errorf("'%#x + %#x > 0' failed", y64, v64_n)
	}

	if y64+v64_n >= 0 {
	} else {
		t.Errorf("'%#x + %#x >= 0' failed", y64, v64_n)
	}

	// Conditions that must not hold; taking the branch is the failure.
	if x64+v64 > 0 {
		t.Errorf("'%#x + %#x > 0' failed", x64, v64)
	}

	if x64+v64 >= 0 {
		t.Errorf("'%#x + %#x >= 0' failed", x64, v64)
	}

	if y64+v64_n < 0 {
		t.Errorf("'%#x + %#x < 0' failed", y64, v64_n)
	}

	if y64+v64_n <= 0 {
		t.Errorf("'%#x + %#x <= 0' failed", y64, v64_n)
	}
}
201
// testAddVar64Cset checks 64-bit "variable + variable" comparisons against
// zero where the outcome is first stored in a local flag and only then
// inspected, instead of branching straight to the failure report.
// NOTE(review): per the original "cset" label this presumably targets the
// conditional-set lowering of the comparison — confirm.
func testAddVar64Cset(t *testing.T) {
	// The first two conditions must hold, so the flag must end up as 1.
	var a int
	if x64+v64 < 0 {
		a = 1
	}
	if a != 1 {
		t.Errorf("'%#x + %#x < 0' failed", x64, v64)
	}

	a = 0
	if y64+v64_n >= 0 {
		a = 1
	}
	if a != 1 {
		t.Errorf("'%#x + %#x >= 0' failed", y64, v64_n)
	}

	// The last two conditions must not hold, so the flag must stay 1; a
	// flag cleared to 0 means the branch was wrongly taken.
	a = 1
	if x64+v64 >= 0 {
		a = 0
	}
	if a == 0 {
		t.Errorf("'%#x + %#x >= 0' failed", x64, v64)
	}

	a = 1
	if y64+v64_n < 0 {
		a = 0
	}
	if a == 0 {
		t.Errorf("'%#x + %#x < 0' failed", y64, v64_n)
	}
}
236
// testAddVar32 is the 32-bit counterpart of the "variable + variable" checks.
// x32+v32 (near-max plus 11) wraps to a negative value and y32+v32_n
// (near-min plus -11) wraps to a positive value.
func testAddVar32(t *testing.T) {
	// Conditions that must hold. The empty then-branch is deliberate; the
	// failure is reported from the else branch.
	if x32+v32 < 0 {
	} else {
		t.Errorf("'%#x + %#x < 0' failed", x32, v32)
	}

	if x32+v32 <= 0 {
	} else {
		t.Errorf("'%#x + %#x <= 0' failed", x32, v32)
	}

	if y32+v32_n > 0 {
	} else {
		t.Errorf("'%#x + %#x > 0' failed", y32, v32_n)
	}

	if y32+v32_n >= 0 {
	} else {
		t.Errorf("'%#x + %#x >= 0' failed", y32, v32_n)
	}

	// Conditions that must not hold; taking the branch is the failure.
	if x32+v32 > 0 {
		t.Errorf("'%#x + %#x > 0' failed", x32, v32)
	}

	if x32+v32 >= 0 {
		t.Errorf("'%#x + %#x >= 0' failed", x32, v32)
	}

	if y32+v32_n < 0 {
		t.Errorf("'%#x + %#x < 0' failed", y32, v32_n)
	}

	if y32+v32_n <= 0 {
		t.Errorf("'%#x + %#x <= 0' failed", y32, v32_n)
	}
}
275
// testMAddVar64 checks 64-bit multiply-add expressions (a + b*c) compared
// against zero. The multiplier one64 is 1, so the values match the plain
// additions: x64+v64*one64 wraps to a negative value and y64+v64_n*one64
// wraps to a positive value.
func testMAddVar64(t *testing.T) {
	// Conditions that must hold. The empty then-branch is deliberate; the
	// failure is reported from the else branch.
	if x64+v64*one64 < 0 {
	} else {
		t.Errorf("'%#x + %#x*1 < 0' failed", x64, v64)
	}

	if x64+v64*one64 <= 0 {
	} else {
		t.Errorf("'%#x + %#x*1 <= 0' failed", x64, v64)
	}

	if y64+v64_n*one64 > 0 {
	} else {
		t.Errorf("'%#x + %#x*1 > 0' failed", y64, v64_n)
	}

	if y64+v64_n*one64 >= 0 {
	} else {
		t.Errorf("'%#x + %#x*1 >= 0' failed", y64, v64_n)
	}

	// Conditions that must not hold; taking the branch is the failure.
	if x64+v64*one64 > 0 {
		t.Errorf("'%#x + %#x*1 > 0' failed", x64, v64)
	}

	if x64+v64*one64 >= 0 {
		t.Errorf("'%#x + %#x*1 >= 0' failed", x64, v64)
	}

	if y64+v64_n*one64 < 0 {
		t.Errorf("'%#x + %#x*1 < 0' failed", y64, v64_n)
	}

	if y64+v64_n*one64 <= 0 {
		t.Errorf("'%#x + %#x*1 <= 0' failed", y64, v64_n)
	}
}
314
// testMAddVar32 is the 32-bit counterpart of the multiply-add checks
// (a + b*c compared against zero). The multiplier one32 is 1, so
// x32+v32*one32 wraps to a negative value and y32+v32_n*one32 wraps to a
// positive value.
func testMAddVar32(t *testing.T) {
	// Conditions that must hold. The empty then-branch is deliberate; the
	// failure is reported from the else branch.
	if x32+v32*one32 < 0 {
	} else {
		t.Errorf("'%#x + %#x*1 < 0' failed", x32, v32)
	}

	if x32+v32*one32 <= 0 {
	} else {
		t.Errorf("'%#x + %#x*1 <= 0' failed", x32, v32)
	}

	if y32+v32_n*one32 > 0 {
	} else {
		t.Errorf("'%#x + %#x*1 > 0' failed", y32, v32_n)
	}

	if y32+v32_n*one32 >= 0 {
	} else {
		t.Errorf("'%#x + %#x*1 >= 0' failed", y32, v32_n)
	}

	// Conditions that must not hold; taking the branch is the failure.
	if x32+v32*one32 > 0 {
		t.Errorf("'%#x + %#x*1 > 0' failed", x32, v32)
	}

	if x32+v32*one32 >= 0 {
		t.Errorf("'%#x + %#x*1 >= 0' failed", x32, v32)
	}

	if y32+v32_n*one32 < 0 {
		t.Errorf("'%#x + %#x*1 < 0' failed", y32, v32_n)
	}

	if y32+v32_n*one32 <= 0 {
		t.Errorf("'%#x + %#x*1 <= 0' failed", y32, v32_n)
	}
}
353
// testMSubVar64 checks 64-bit multiply-sub expressions (a - b*c) compared
// against zero. The multiplier one64 is 1, so x64-v64_n*one64 equals x64+11
// and wraps to a negative value, while y64-v64*one64 equals y64-11 and wraps
// to a positive value. The final two checks cover a difference of exactly
// zero.
func testMSubVar64(t *testing.T) {
	// Conditions that must hold. The empty then-branch is deliberate; the
	// failure is reported from the else branch.
	if x64-v64_n*one64 < 0 {
	} else {
		t.Errorf("'%#x - %#x*1 < 0' failed", x64, v64_n)
	}

	if x64-v64_n*one64 <= 0 {
	} else {
		t.Errorf("'%#x - %#x*1 <= 0' failed", x64, v64_n)
	}

	if y64-v64*one64 > 0 {
	} else {
		t.Errorf("'%#x - %#x*1 > 0' failed", y64, v64)
	}

	if y64-v64*one64 >= 0 {
	} else {
		t.Errorf("'%#x - %#x*1 >= 0' failed", y64, v64)
	}

	// Conditions that must not hold; taking the branch is the failure.
	if x64-v64_n*one64 > 0 {
		t.Errorf("'%#x - %#x*1 > 0' failed", x64, v64_n)
	}

	if x64-v64_n*one64 >= 0 {
		t.Errorf("'%#x - %#x*1 >= 0' failed", x64, v64_n)
	}

	if y64-v64*one64 < 0 {
		t.Errorf("'%#x - %#x*1 < 0' failed", y64, v64)
	}

	if y64-v64*one64 <= 0 {
		t.Errorf("'%#x - %#x*1 <= 0' failed", y64, v64)
	}

	// x64 and x64b hold the same value, so the difference is exactly zero:
	// '< 0' must not hold and '>= 0' must hold.
	if x64-x64b*one64 < 0 {
		t.Errorf("'%#x - %#x*1 < 0' failed", x64, x64b)
	}

	if x64-x64b*one64 >= 0 {
	} else {
		t.Errorf("'%#x - %#x*1 >= 0' failed", x64, x64b)
	}
}
401
// testMSubVar32 is the 32-bit counterpart of the multiply-sub checks
// (a - b*c compared against zero). The multiplier one32 is 1, so
// x32-v32_n*one32 equals x32+11 and wraps to a negative value, while
// y32-v32*one32 equals y32-11 and wraps to a positive value. The final two
// checks cover a difference of exactly zero.
func testMSubVar32(t *testing.T) {
	// Conditions that must hold. The empty then-branch is deliberate; the
	// failure is reported from the else branch.
	if x32-v32_n*one32 < 0 {
	} else {
		t.Errorf("'%#x - %#x*1 < 0' failed", x32, v32_n)
	}

	if x32-v32_n*one32 <= 0 {
	} else {
		t.Errorf("'%#x - %#x*1 <= 0' failed", x32, v32_n)
	}

	if y32-v32*one32 > 0 {
	} else {
		t.Errorf("'%#x - %#x*1 > 0' failed", y32, v32)
	}

	if y32-v32*one32 >= 0 {
	} else {
		t.Errorf("'%#x - %#x*1 >= 0' failed", y32, v32)
	}

	// Conditions that must not hold; taking the branch is the failure.
	if x32-v32_n*one32 > 0 {
		t.Errorf("'%#x - %#x*1 > 0' failed", x32, v32_n)
	}

	if x32-v32_n*one32 >= 0 {
		t.Errorf("'%#x - %#x*1 >= 0' failed", x32, v32_n)
	}

	if y32-v32*one32 < 0 {
		t.Errorf("'%#x - %#x*1 < 0' failed", y32, v32)
	}

	if y32-v32*one32 <= 0 {
		t.Errorf("'%#x - %#x*1 <= 0' failed", y32, v32)
	}

	// x32 and x32b hold the same value, so the difference is exactly zero:
	// '< 0' must not hold and '>= 0' must hold.
	if x32-x32b*one32 < 0 {
		t.Errorf("'%#x - %#x*1 < 0' failed", x32, x32b)
	}

	if x32-x32b*one32 >= 0 {
	} else {
		t.Errorf("'%#x - %#x*1 >= 0' failed", x32, x32b)
	}
}
449
// testAddShift32 checks 32-bit "a + (b shifted)" expressions compared against
// zero, picking one or two scenarios per condition. Shifts bind tighter than
// '+' in Go, so x32+v32<<1 means x32 + (v32<<1). Each shifted operand is a
// small positive value (22, 5 or 9), so every sum wraps to a negative value.
// The first three checks use a constant shift count; the last three use the
// variable count uz. uv32 is unsigned, so its right shift is a logical one.
func testAddShift32(t *testing.T) {
	// Must hold: the sum is negative.
	if x32+v32<<1 < 0 {
	} else {
		t.Errorf("'%#x + %#x<<%#x < 0' failed", x32, v32, 1)
	}

	// Must hold: the sum is negative.
	if x32+v32>>1 <= 0 {
	} else {
		t.Errorf("'%#x + %#x>>%#x <= 0' failed", x32, v32, 1)
	}

	// Must not hold: the sum is negative.
	if x32+int32(uv32>>1) > 0 {
		t.Errorf("'%#x + int32(%#x>>%#x) > 0' failed", x32, uv32, 1)
	}

	// Must not hold: the sum is negative.
	if x32+v32<<uz >= 0 {
		t.Errorf("'%#x + %#x<<%#x >= 0' failed", x32, v32, uz)
	}

	// Must not hold: the sum is negative.
	if x32+v32>>uz > 0 {
		t.Errorf("'%#x + %#x>>%#x > 0' failed", x32, v32, uz)
	}

	// Must hold: the sum is negative.
	if x32+int32(uv32>>uz) < 0 {
	} else {
		t.Errorf("'%#x + int32(%#x>>%#x) < 0' failed", x32, uv32, uz)
	}
}
479
// testSubShift32 checks 32-bit "a - (b shifted)" expressions compared against
// zero, picking one or two scenarios per condition. Shifts bind tighter than
// '-' in Go, so y32-v32<<1 means y32 - (v32<<1). Each shifted operand is a
// small positive value (22, 5 or 9), so every difference wraps to a positive
// value. The first three checks use a constant shift count; the last three
// use the variable count uz. uv32 is unsigned, so its right shift is a
// logical one.
func testSubShift32(t *testing.T) {
	// Must hold: the difference is positive.
	if y32-v32<<1 > 0 {
	} else {
		t.Errorf("'%#x - %#x<<%#x > 0' failed", y32, v32, 1)
	}

	// Must not hold: the difference is positive.
	if y32-v32>>1 < 0 {
		t.Errorf("'%#x - %#x>>%#x < 0' failed", y32, v32, 1)
	}

	// Must hold: the difference is positive.
	if y32-int32(uv32>>1) >= 0 {
	} else {
		t.Errorf("'%#x - int32(%#x>>%#x) >= 0' failed", y32, uv32, 1)
	}

	// Must not hold: the difference is positive.
	if y32-v32<<uz < 0 {
		t.Errorf("'%#x - %#x<<%#x < 0' failed", y32, v32, uz)
	}

	// Must hold: the difference is positive.
	if y32-v32>>uz >= 0 {
	} else {
		t.Errorf("'%#x - %#x>>%#x >= 0' failed", y32, v32, uz)
	}

	// Must not hold: the difference is positive.
	if y32-int32(uv32>>uz) <= 0 {
		t.Errorf("'%#x - int32(%#x>>%#x) <= 0' failed", y32, uv32, uz)
	}
}
509
// rnd is a random source with the fixed seed 0; the benchmarks draw the
// initial value of their accumulator from it.
var rnd = rand.New(rand.NewSource(0))

// sink receives each benchmark's final accumulator value.
// NOTE(review): presumably this keeps the compiler from discarding the
// benchmark loop bodies as dead code — confirm.
var sink int64
512
// benchSoloJump times ten compare-with-zero conditions that use only '<' and
// '>='. The operands are copies of the int64 boundary values declared at
// package level, so the arithmetic wraps around. Each taken branch doubles and
// then halves the accumulator d, which starts as a value in [0, 10) drawn from
// rnd and is stored to sink once the loop finishes.
func benchSoloJump(b *testing.B) {
	r1 := x64
	r2 := x64b
	r3 := x64c
	r4 := y64
	d := rnd.Int63n(10)

	// 6 out of 10 conditions evaluate to true
	for i := 0; i < b.N; i++ {
		// true: the sum wraps to a negative value
		if r1+r2 < 0 {
			d *= 2
			d /= 2
		}

		// false: the sum wraps to a negative value
		if r1+r3 >= 0 {
			d *= 2
			d /= 2
		}

		// true: same value as r1+r2, in multiply-add form
		if r1+r2*one64 < 0 {
			d *= 2
			d /= 2
		}

		// false: the sum wraps to a negative value
		if r2+r3*one64 >= 0 {
			d *= 2
			d /= 2
		}

		// true: the wrapped result is positive
		if r1-r2*v64 >= 0 {
			d *= 2
			d /= 2
		}

		// true: the wrapped result is negative
		if r3-r4*v64 < 0 {
			d *= 2
			d /= 2
		}

		// true: the sum wraps to a negative value
		if r1+11 < 0 {
			d *= 2
			d /= 2
		}

		// false: the sum wraps to a negative value
		if r1+13 >= 0 {
			d *= 2
			d /= 2
		}

		// false: the difference wraps to a positive value
		if r4-17 < 0 {
			d *= 2
			d /= 2
		}

		// true: the difference wraps to a positive value
		if r4-19 >= 0 {
			d *= 2
			d /= 2
		}
	}
	sink = d
}
574
// benchCombJump times ten compare-with-zero conditions that use only '<=' and
// '>'. The operands are copies of the int64 boundary values declared at
// package level, so the arithmetic wraps around. Each taken branch doubles and
// then halves the accumulator d, which starts as a value in [0, 10) drawn from
// rnd and is stored to sink once the loop finishes.
func benchCombJump(b *testing.B) {
	r1 := x64
	r2 := x64b
	r3 := x64c
	r4 := y64
	d := rnd.Int63n(10)

	// 6 out of 10 conditions evaluate to true
	for i := 0; i < b.N; i++ {
		// true: the sum wraps to a negative value
		if r1+r2 <= 0 {
			d *= 2
			d /= 2
		}

		// false: the sum wraps to a negative value
		if r1+r3 > 0 {
			d *= 2
			d /= 2
		}

		// true: same value as r1+r2, in multiply-add form
		if r1+r2*one64 <= 0 {
			d *= 2
			d /= 2
		}

		// false: the sum wraps to a negative value
		if r2+r3*one64 > 0 {
			d *= 2
			d /= 2
		}

		// true: the wrapped result is positive
		if r1-r2*v64 > 0 {
			d *= 2
			d /= 2
		}

		// true: the wrapped result is negative
		if r3-r4*v64 <= 0 {
			d *= 2
			d /= 2
		}

		// true: the sum wraps to a negative value
		if r1+11 <= 0 {
			d *= 2
			d /= 2
		}

		// false: the sum wraps to a negative value
		if r1+13 > 0 {
			d *= 2
			d /= 2
		}

		// false: the difference wraps to a positive value
		if r4-17 <= 0 {
			d *= 2
			d /= 2
		}

		// true: the difference wraps to a positive value
		if r4-19 > 0 {
			d *= 2
			d /= 2
		}
	}
	sink = d
}
636