1// run -goexperiment rangefunc
2
3// Copyright 2023 The Go Authors. All rights reserved.
4// Use of this source code is governed by a BSD-style
5// license that can be found in the LICENSE file.
6
7// Test the 'for range' construct ranging over functions.
8
9package main
10
// gj is a package-level int. Nothing in this file currently reads or writes it.
var (
	gj int
)
12
// yield4x is an iterator with no values: it calls yield up to four times
// and stops early as soon as yield returns false.
func yield4x(yield func() bool) {
	for n := 0; n < 4; n++ {
		if !yield() {
			return
		}
	}
}
16
// yield4 is a single-value iterator that yields 1, 2, 3, 4 in order,
// stopping early as soon as yield returns false.
func yield4(yield func(int) bool) {
	for v := 1; v <= 4; v++ {
		if !yield(v) {
			return
		}
	}
}
20
// yield3 is a single-value iterator that yields 1, 2, 3 in order,
// stopping early as soon as yield returns false.
func yield3(yield func(int) bool) {
	for v := 1; v <= 3; v++ {
		if !yield(v) {
			return
		}
	}
}
24
// yield2 is a single-value iterator that yields 1, 2 in order,
// stopping early as soon as yield returns false.
func yield2(yield func(int) bool) {
	for v := 1; v <= 2; v++ {
		if !yield(v) {
			return
		}
	}
}
28
29func testfunc0() {
30	j := 0
31	for range yield4x {
32		j++
33	}
34	if j != 4 {
35		println("wrong count ranging over yield4x:", j)
36		panic("testfunc0")
37	}
38
39	j = 0
40	for _ = range yield4 {
41		j++
42	}
43	if j != 4 {
44		println("wrong count ranging over yield4:", j)
45		panic("testfunc0")
46	}
47}
48
49func testfunc1() {
50	bad := false
51	j := 1
52	for i := range yield4 {
53		if i != j {
54			println("range var", i, "want", j)
55			bad = true
56		}
57		j++
58	}
59	if j != 5 {
60		println("wrong count ranging over f:", j)
61		bad = true
62	}
63	if bad {
64		panic("testfunc1")
65	}
66}
67
68func testfunc2() {
69	bad := false
70	j := 1
71	var i int
72	for i = range yield4 {
73		if i != j {
74			println("range var", i, "want", j)
75			bad = true
76		}
77		j++
78	}
79	if j != 5 {
80		println("wrong count ranging over f:", j)
81		bad = true
82	}
83	if i != 4 {
84		println("wrong final i ranging over f:", i)
85		bad = true
86	}
87	if bad {
88		panic("testfunc2")
89	}
90}
91
92func testfunc3() {
93	bad := false
94	j := 1
95	var i int
96	for i = range yield4 {
97		if i != j {
98			println("range var", i, "want", j)
99			bad = true
100		}
101		j++
102		if i == 2 {
103			break
104		}
105		continue
106	}
107	if j != 3 {
108		println("wrong count ranging over f:", j)
109		bad = true
110	}
111	if i != 2 {
112		println("wrong final i ranging over f:", i)
113		bad = true
114	}
115	if bad {
116		panic("testfunc3")
117	}
118}
119
120func testfunc4() {
121	bad := false
122	j := 1
123	var i int
124	func() {
125		for i = range yield4 {
126			if i != j {
127				println("range var", i, "want", j)
128				bad = true
129			}
130			j++
131			if i == 2 {
132				return
133			}
134		}
135	}()
136	if j != 3 {
137		println("wrong count ranging over f:", j)
138		bad = true
139	}
140	if i != 2 {
141		println("wrong final i ranging over f:", i)
142		bad = true
143	}
144	if bad {
145		panic("testfunc3")
146	}
147}
148
149func func5() (int, int) {
150	for i := range yield4 {
151		return 10, i
152	}
153	panic("still here")
154}
155
156func testfunc5() {
157	x, y := func5()
158	if x != 10 || y != 1 {
159		println("wrong results", x, y, "want", 10, 1)
160		panic("testfunc5")
161	}
162}
163
164func func6() (z, w int) {
165	for i := range yield4 {
166		z = 10
167		w = i
168		return
169	}
170	panic("still here")
171}
172
173func testfunc6() {
174	x, y := func6()
175	if x != 10 || y != 1 {
176		println("wrong results", x, y, "want", 10, 1)
177		panic("testfunc6")
178	}
179}
180
// saved collects the values passed to save, in call order. The defer tests
// reset it to nil before each run.
var (
	saved []int
)
182
183func save(x int) {
184	saved = append(saved, x)
185}
186
// printslice prints s with the print builtin as "[a, b, c]", without a
// trailing newline.
func printslice(s []int) {
	print("[")
	sep := ""
	for _, x := range s {
		print(sep)
		print(x)
		sep = ", "
	}
	print("]")
}
197
// eqslice reports whether s and t have the same length and hold equal
// elements at every index. A nil slice equals an empty one.
func eqslice(s, t []int) bool {
	n := len(s)
	if n != len(t) {
		return false
	}
	for i := 0; i < n; i++ {
		if s[i] != t[i] {
			return false
		}
	}
	return true
}
209
210func func7() {
211	defer save(-1)
212	for i := range yield4 {
213		defer save(i)
214	}
215	defer save(5)
216}
217
218func checkslice(name string, saved, want []int) {
219	if !eqslice(saved, want) {
220		print("wrong results ")
221		printslice(saved)
222		print(" want ")
223		printslice(want)
224		print("\n")
225		panic(name)
226	}
227}
228
229func testfunc7() {
230	saved = nil
231	func7()
232	want := []int{5, 4, 3, 2, 1, -1}
233	checkslice("testfunc7", saved, want)
234}
235
236func func8() {
237	defer save(-1)
238	for i := range yield2 {
239		for j := range yield3 {
240			defer save(i*10 + j)
241		}
242		defer save(i)
243	}
244	defer save(-2)
245	for i := range yield4 {
246		defer save(i)
247	}
248	defer save(-3)
249}
250
251func testfunc8() {
252	saved = nil
253	func8()
254	want := []int{-3, 4, 3, 2, 1, -2, 2, 23, 22, 21, 1, 13, 12, 11, -1}
255	checkslice("testfunc8", saved, want)
256}
257
258func func9() {
259	n := 0
260	for _ = range yield2 {
261		for _ = range yield3 {
262			n++
263			defer save(n)
264		}
265	}
266}
267
268func testfunc9() {
269	saved = nil
270	func9()
271	want := []int{6, 5, 4, 3, 2, 1}
272	checkslice("testfunc9", saved, want)
273}
274
// Test that range evaluates the index and value expressions
// exactly once per iteration.
277
// ncalls counts how many times getvar has been called. testcalls resets it
// before use.
var ncalls int
279
280func getvar(p *int) *int {
281	ncalls++
282	return p
283}
284
// iter2 returns a two-value iterator over list. It yields (index, element)
// pairs in order and stops as soon as yield returns false.
func iter2(list ...int) func(func(int, int) bool) {
	return func(yield func(int, int) bool) {
		for i := 0; i < len(list); i++ {
			if !yield(i, list[i]) {
				return
			}
		}
	}
}
294
295func testcalls() {
296	var i, v int
297	ncalls = 0
298	si := 0
299	sv := 0
300	for *getvar(&i), *getvar(&v) = range iter2(1, 2) {
301		si += i
302		sv += v
303	}
304	if ncalls != 4 {
305		println("wrong number of calls:", ncalls, "!= 4")
306		panic("fail")
307	}
308	if si != 1 || sv != 3 {
309		println("wrong sum in testcalls", si, sv)
310		panic("fail")
311	}
312}
313
// iter3YieldFunc is a defined (named) yield function type. iter2 takes a
// plain func literal type, whereas iter3 takes this named type, so ranging
// over iter3 exercises an iterator whose yield parameter has a named type.
// Parameter names do not affect type identity.
type iter3YieldFunc func(key, val int) bool
315
316func iter3(list ...int) func(iter3YieldFunc) {
317	return func(yield iter3YieldFunc) {
318		for k, v := range list {
319			if !yield(k, v) {
320				return
321			}
322		}
323	}
324}
325
326func testcalls1() {
327	ncalls := 0
328	for k, v := range iter3(1, 2, 3) {
329		_, _ = k, v
330		ncalls++
331	}
332	if ncalls != 3 {
333		println("wrong number of calls:", ncalls, "!= 3")
334		panic("fail")
335	}
336}
337
338func main() {
339	testfunc0()
340	testfunc1()
341	testfunc2()
342	testfunc3()
343	testfunc4()
344	testfunc5()
345	testfunc6()
346	testfunc7()
347	testfunc8()
348	testfunc9()
349	testcalls()
350	testcalls1()
351}
352