1// Copyright 2016 The Go Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
// This program generates a test to verify that the standard arithmetic
// operators properly handle const cases. The test file should be
// generated with a known working version of go.
// Launch with `go run arithConstGen.go`; a file called arithConst_test.go
// containing the tests will be written into the parent directory.
10
11package main
12
import (
	"bytes"
	"fmt"
	"go/format"
	"log"
	"os"
	"strings"
	"text/template"
)
21
// op describes one binary operator under test.
type op struct {
	name   string // identifier fragment used in generated function names, e.g. "add"
	symbol string // Go operator token emitted into generated code, e.g. "+"
}
25type szD struct {
26	name   string
27	sn     string
28	u      []uint64
29	i      []int64
30	oponly string
31}
32
33var szs = []szD{
34	{name: "uint64", sn: "64", u: []uint64{0, 1, 4294967296, 0x8000000000000000, 0xffffFFFFffffFFFF}},
35	{name: "uint64", sn: "64", u: []uint64{3, 5, 7, 9, 10, 11, 13, 19, 21, 25, 27, 37, 41, 45, 73, 81}, oponly: "mul"},
36
37	{name: "int64", sn: "64", i: []int64{-0x8000000000000000, -0x7FFFFFFFFFFFFFFF,
38		-4294967296, -1, 0, 1, 4294967296, 0x7FFFFFFFFFFFFFFE, 0x7FFFFFFFFFFFFFFF}},
39	{name: "int64", sn: "64", i: []int64{-9, -5, -3, 3, 5, 7, 9, 10, 11, 13, 19, 21, 25, 27, 37, 41, 45, 73, 81}, oponly: "mul"},
40
41	{name: "uint32", sn: "32", u: []uint64{0, 1, 4294967295}},
42	{name: "uint32", sn: "32", u: []uint64{3, 5, 7, 9, 10, 11, 13, 19, 21, 25, 27, 37, 41, 45, 73, 81}, oponly: "mul"},
43
44	{name: "int32", sn: "32", i: []int64{-0x80000000, -0x7FFFFFFF, -1, 0,
45		1, 0x7FFFFFFF}},
46	{name: "int32", sn: "32", i: []int64{-9, -5, -3, 3, 5, 7, 9, 10, 11, 13, 19, 21, 25, 27, 37, 41, 45, 73, 81}, oponly: "mul"},
47
48	{name: "uint16", sn: "16", u: []uint64{0, 1, 65535}},
49	{name: "int16", sn: "16", i: []int64{-32768, -32767, -1, 0, 1, 32766, 32767}},
50
51	{name: "uint8", sn: "8", u: []uint64{0, 1, 255}},
52	{name: "int8", sn: "8", i: []int64{-128, -127, -1, 0, 1, 126, 127}},
53}
54
55var ops = []op{
56	{"add", "+"},
57	{"sub", "-"},
58	{"div", "/"},
59	{"mul", "*"},
60	{"lsh", "<<"},
61	{"rsh", ">>"},
62	{"mod", "%"},
63	{"and", "&"},
64	{"or", "|"},
65	{"xor", "^"},
66}
67
68// compute the result of i op j, cast as type t.
69func ansU(i, j uint64, t, op string) string {
70	var ans uint64
71	switch op {
72	case "+":
73		ans = i + j
74	case "-":
75		ans = i - j
76	case "*":
77		ans = i * j
78	case "/":
79		if j != 0 {
80			ans = i / j
81		}
82	case "%":
83		if j != 0 {
84			ans = i % j
85		}
86	case "<<":
87		ans = i << j
88	case ">>":
89		ans = i >> j
90	case "&":
91		ans = i & j
92	case "|":
93		ans = i | j
94	case "^":
95		ans = i ^ j
96	}
97	switch t {
98	case "uint32":
99		ans = uint64(uint32(ans))
100	case "uint16":
101		ans = uint64(uint16(ans))
102	case "uint8":
103		ans = uint64(uint8(ans))
104	}
105	return fmt.Sprintf("%d", ans)
106}
107
// ansS evaluates the signed expression "i op j" with 64-bit two's-complement
// wrap-around, sign-extends the result from the width of the Go type named by
// t, and returns it in decimal. op is one of "+", "-", "*", "/", "%", "<<",
// ">>", "&", "|", "^"; shift counts are reinterpreted as uint64. Division or
// remainder by zero, and any unrecognized op, yield "0". Type names other than
// "int32", "int16" and "int8" keep all 64 bits.
func ansS(i, j int64, t, op string) string {
	var res int64
	switch op {
	case "+":
		res = i + j
	case "-":
		res = i - j
	case "*":
		res = i * j
	case "/", "%":
		switch {
		case j == 0:
			// Leave res at zero; the generator skips zero divisors.
		case op == "/":
			res = i / j
		default:
			res = i % j
		}
	case "<<":
		res = i << uint64(j)
	case ">>":
		res = i >> uint64(j)
	case "&":
		res = i & j
	case "|":
		res = i | j
	case "^":
		res = i ^ j
	}

	// Sign-extend from the width of t. Shifting left then arithmetically
	// right by the same amount equals converting to that narrower signed
	// type and back to int64.
	var bits uint
	switch t {
	case "int8":
		bits = 8
	case "int16":
		bits = 16
	case "int32":
		bits = 32
	}
	if bits > 0 {
		drop := 64 - bits
		res = res << drop >> drop
	}
	return fmt.Sprintf("%d", res)
}
147
148func main() {
149	w := new(bytes.Buffer)
150	fmt.Fprintf(w, "// Code generated by gen/arithConstGen.go. DO NOT EDIT.\n\n")
151	fmt.Fprintf(w, "package main;\n")
152	fmt.Fprintf(w, "import \"testing\"\n")
153
154	fncCnst1 := template.Must(template.New("fnc").Parse(
155		`//go:noinline
156func {{.Name}}_{{.Type_}}_{{.FNumber}}(a {{.Type_}}) {{.Type_}} { return a {{.Symbol}} {{.Number}} }
157`))
158	fncCnst2 := template.Must(template.New("fnc").Parse(
159		`//go:noinline
160func {{.Name}}_{{.FNumber}}_{{.Type_}}(a {{.Type_}}) {{.Type_}} { return {{.Number}} {{.Symbol}} a }
161`))
162
163	type fncData struct {
164		Name, Type_, Symbol, FNumber, Number string
165	}
166
167	for _, s := range szs {
168		for _, o := range ops {
169			if s.oponly != "" && s.oponly != o.name {
170				continue
171			}
172			fd := fncData{o.name, s.name, o.symbol, "", ""}
173
174			// unsigned test cases
175			if len(s.u) > 0 {
176				for _, i := range s.u {
177					fd.Number = fmt.Sprintf("%d", i)
178					fd.FNumber = strings.Replace(fd.Number, "-", "Neg", -1)
179
180					// avoid division by zero
181					if o.name != "mod" && o.name != "div" || i != 0 {
182						// introduce uint64 cast for rhs shift operands
183						// if they are too large for default uint type
184						number := fd.Number
185						if (o.name == "lsh" || o.name == "rsh") && uint64(uint32(i)) != i {
186							fd.Number = fmt.Sprintf("uint64(%s)", number)
187						}
188						fncCnst1.Execute(w, fd)
189						fd.Number = number
190					}
191
192					fncCnst2.Execute(w, fd)
193				}
194			}
195
196			// signed test cases
197			if len(s.i) > 0 {
198				// don't generate tests for shifts by signed integers
199				if o.name == "lsh" || o.name == "rsh" {
200					continue
201				}
202				for _, i := range s.i {
203					fd.Number = fmt.Sprintf("%d", i)
204					fd.FNumber = strings.Replace(fd.Number, "-", "Neg", -1)
205
206					// avoid division by zero
207					if o.name != "mod" && o.name != "div" || i != 0 {
208						fncCnst1.Execute(w, fd)
209					}
210					fncCnst2.Execute(w, fd)
211				}
212			}
213		}
214	}
215
216	vrf1 := template.Must(template.New("vrf1").Parse(`
217		test_{{.Size}}{fn: {{.Name}}_{{.FNumber}}_{{.Type_}}, fnname: "{{.Name}}_{{.FNumber}}_{{.Type_}}", in: {{.Input}}, want: {{.Ans}}},`))
218
219	vrf2 := template.Must(template.New("vrf2").Parse(`
220		test_{{.Size}}{fn: {{.Name}}_{{.Type_}}_{{.FNumber}}, fnname: "{{.Name}}_{{.Type_}}_{{.FNumber}}", in: {{.Input}}, want: {{.Ans}}},`))
221
222	type cfncData struct {
223		Size, Name, Type_, Symbol, FNumber, Number string
224		Ans, Input                                 string
225	}
226	for _, s := range szs {
227		fmt.Fprintf(w, `
228type test_%[1]s%[2]s struct {
229	fn func (%[1]s) %[1]s
230	fnname string
231	in %[1]s
232	want %[1]s
233}
234`, s.name, s.oponly)
235		fmt.Fprintf(w, "var tests_%[1]s%[2]s =[]test_%[1]s {\n\n", s.name, s.oponly)
236
237		if len(s.u) > 0 {
238			for _, o := range ops {
239				if s.oponly != "" && s.oponly != o.name {
240					continue
241				}
242				fd := cfncData{s.name, o.name, s.name, o.symbol, "", "", "", ""}
243				for _, i := range s.u {
244					fd.Number = fmt.Sprintf("%d", i)
245					fd.FNumber = strings.Replace(fd.Number, "-", "Neg", -1)
246
247					// unsigned
248					for _, j := range s.u {
249
250						if o.name != "mod" && o.name != "div" || j != 0 {
251							fd.Ans = ansU(i, j, s.name, o.symbol)
252							fd.Input = fmt.Sprintf("%d", j)
253							if err := vrf1.Execute(w, fd); err != nil {
254								panic(err)
255							}
256						}
257
258						if o.name != "mod" && o.name != "div" || i != 0 {
259							fd.Ans = ansU(j, i, s.name, o.symbol)
260							fd.Input = fmt.Sprintf("%d", j)
261							if err := vrf2.Execute(w, fd); err != nil {
262								panic(err)
263							}
264						}
265
266					}
267				}
268
269			}
270		}
271
272		// signed
273		if len(s.i) > 0 {
274			for _, o := range ops {
275				if s.oponly != "" && s.oponly != o.name {
276					continue
277				}
278				// don't generate tests for shifts by signed integers
279				if o.name == "lsh" || o.name == "rsh" {
280					continue
281				}
282				fd := cfncData{s.name, o.name, s.name, o.symbol, "", "", "", ""}
283				for _, i := range s.i {
284					fd.Number = fmt.Sprintf("%d", i)
285					fd.FNumber = strings.Replace(fd.Number, "-", "Neg", -1)
286					for _, j := range s.i {
287						if o.name != "mod" && o.name != "div" || j != 0 {
288							fd.Ans = ansS(i, j, s.name, o.symbol)
289							fd.Input = fmt.Sprintf("%d", j)
290							if err := vrf1.Execute(w, fd); err != nil {
291								panic(err)
292							}
293						}
294
295						if o.name != "mod" && o.name != "div" || i != 0 {
296							fd.Ans = ansS(j, i, s.name, o.symbol)
297							fd.Input = fmt.Sprintf("%d", j)
298							if err := vrf2.Execute(w, fd); err != nil {
299								panic(err)
300							}
301						}
302
303					}
304				}
305
306			}
307		}
308
309		fmt.Fprintf(w, "}\n\n")
310	}
311
312	fmt.Fprint(w, `
313
314// TestArithmeticConst tests results for arithmetic operations against constants.
315func TestArithmeticConst(t *testing.T) {
316`)
317
318	for _, s := range szs {
319		fmt.Fprintf(w, `for _, test := range tests_%s%s {`, s.name, s.oponly)
320		// Use WriteString here to avoid a vet warning about formatting directives.
321		w.WriteString(`if got := test.fn(test.in); got != test.want {
322			t.Errorf("%s(%d) = %d, want %d\n", test.fnname, test.in, got, test.want)
323		}
324	}
325`)
326	}
327
328	fmt.Fprint(w, `
329}
330`)
331
332	// gofmt result
333	b := w.Bytes()
334	src, err := format.Source(b)
335	if err != nil {
336		fmt.Printf("%s\n", b)
337		panic(err)
338	}
339
340	// write to file
341	err = os.WriteFile("../arithConst_test.go", src, 0666)
342	if err != nil {
343		log.Fatalf("can't write output: %v\n", err)
344	}
345}
346