xref: /aosp_15_r20/external/mesa3d/src/compiler/glsl/tests/lower_jump_cases.py (revision 6104692788411f58d303aa86923a9ff6ecaded22)
1# coding=utf-8
2#
3# Copyright © 2011, 2018 Intel Corporation
4#
5# Permission is hereby granted, free of charge, to any person obtaining a
6# copy of this software and associated documentation files (the "Software"),
7# to deal in the Software without restriction, including without limitation
8# the rights to use, copy, modify, merge, publish, distribute, sublicense,
9# and/or sell copies of the Software, and to permit persons to whom the
10# Software is furnished to do so, subject to the following conditions:
11#
12# The above copyright notice and this permission notice (including the next
13# paragraph) shall be included in all copies or substantial portions of the
14# Software.
15#
16# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.  IN NO EVENT SHALL
19# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
21# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
22# DEALINGS IN THE SOFTWARE.
23
24from sexps import *
25
def make_test_case(f_name, ret_type, body):
    """Wrap body in a single-function test case plus global declarations.

    The result is a list of top-level sexps.  It starts with one global
    declaration for every variable that body references without declaring
    it.  It ends with a function named f_name, whose single signature
    returns ret_type, takes no parameters and uses body as its
    instruction list.

    A variable that appears only as a var_ref is declared (in).  A
    variable that is the target of an assign is declared (out).  Every
    such global is given type float.
    """
    check_sexp(body)
    globals_by_name = {}

    def add_global(name, direction):
        # Re-adding a name replaces its declaration but keeps the position
        # at which the name was first seen.
        globals_by_name[name] = ['declare', [direction], 'float', name]

    def scan(node, in_scope=()):
        if not isinstance(node, list):
            return
        if len(node) == 2 and node[0] == 'var_ref':
            if node[1] not in in_scope:
                add_global(node[1], 'in')
            return
        if len(node) == 4 and node[0] == 'assign':
            target = node[2]
            assert target[0] == 'var_ref'
            if target[1] not in in_scope:
                add_global(target[1], 'out')
            scan(node[3], in_scope)
            return
        # Any other list opens a nested scope.  A declaration seen here is
        # visible to the siblings scanned after it, and to their children,
        # but never to the enclosing list.
        local_scope = set(in_scope)
        for child in node:
            if isinstance(child, list) and len(child) >= 4 and \
                    child[0] == 'declare':
                local_scope.add(child[3])
            else:
                scan(child, local_scope)

    scan(body)
    function = ['function', f_name,
                ['signature', ret_type, ['parameters'], body]]
    return list(globals_by_name.values()) + [function]
59
60
61# The following functions can be used to build expressions.
62
def const_float(value):
    """Return a float constant sexp whose literal has six decimal places.

    For example, const_float(1) yields (constant float (1.000000)).
    """
    literal = format(value, '.6f')
    return ['constant', 'float', [literal]]
66
def const_bool(value):
    """Return a bool constant sexp, (constant bool (1)) or (constant bool (0)).

    value is interpreted by its truthiness, so const_bool(1) is the same
    as const_bool(True) and const_bool([]) is the same as const_bool(False).
    """
    digit = str(int(bool(value)))
    return ['constant', 'bool', [digit]]
74
def gt_zero(var_name):
    """Create the expression var_name > 0.

    The comparison is encoded as '<' with the operands swapped, i.e.
    (expression bool < (constant float (0.000000)) (var_ref var_name)),
    which means 0.0 < var_name.
    """
    return ['expression', 'bool', '<', const_float(0), ['var_ref', var_name]]
78
79
80# The following functions can be used to build complex control flow
81# statements.  All of these functions return statement lists (even
82# those which only create a single statement), so that statements can
83# be sequenced together using the '+' operator.
84
def return_(value=None):
    """Return a one-statement list holding a return instruction.

    With no value (or value=None) the statement is (return).  Otherwise
    it is (return <value>).
    """
    statement = ['return'] if value is None else ['return', value]
    return [statement]
91
def break_():
    """Return a statement list holding a single break statement.

    Unlike return_(), the statement is the bare symbol 'break' rather
    than a nested list.
    """
    statement = 'break'
    return [statement]
95
def continue_():
    """Return a statement list holding a single continue statement.

    Unlike return_(), the statement is the bare symbol 'continue' rather
    than a nested list.
    """
    statement = 'continue'
    return [statement]
99
def simple_if(var_name, then_statements, else_statements=None):
    """Create a one-statement list of the form

    if (var_name > 0.0) {
       <then_statements>
    } else {
       <else_statements>
    }

    Omitting else_statements (or passing None) gives an empty else branch.
    """
    then_branch = then_statements
    else_branch = [] if else_statements is None else else_statements
    for branch in (then_branch, else_branch):
        check_sexp(branch)
    condition = gt_zero(var_name)
    return [['if', condition, then_branch, else_branch]]
116
def loop(statements):
    """Return a one-statement list holding (loop <statements>)."""
    check_sexp(statements)
    loop_node = ['loop', statements]
    return [loop_node]
123
def declare_temp(var_type, var_name):
    """Create a one-statement list holding a declaration of the form

    (declare (temporary) <var_type> <var_name>)
    """
    return [['declare', ['temporary'], var_type, var_name]]
130
def assign_x(var_name, value):
    """Return a one-statement list assigning value to var_name.

    The statement is (assign (x) (var_ref <var_name>) <value>), i.e. it
    uses the write mask (x).
    """
    check_sexp(value)
    write_mask = ['x']
    target = ['var_ref', var_name]
    return [['assign', write_mask, target, value]]
137
def complex_if(var_prefix, statements):
    """Create a statement list of the form

    if (<var_prefix>a > 0.0) {
       if (<var_prefix>b > 0.0) {
          <statements>
       }
    }

    This shape is useful for testing jump lowering.  When <statements>
    ends in a jump, lower_jumps.cpp won't try to combine this construct
    with the code that follows it, as it might do for a simple if.

    var_prefix is prepended to both condition variables so that callers
    can keep variable names unique.
    """
    check_sexp(statements)
    outer_var = var_prefix + 'a'
    inner_var = var_prefix + 'b'
    inner_if = simple_if(inner_var, statements)
    return simple_if(outer_var, inner_if)
157
def declare_execute_flag():
    """Return the statements lower_jumps.cpp emits to create execute_flag.

    This is a temporary bool declaration followed by an assignment of
    true.
    """
    declaration = declare_temp('bool', 'execute_flag')
    initialization = assign_x('execute_flag', const_bool(True))
    return declaration + initialization
164
def declare_return_flag():
    """Return the statements lower_jumps.cpp emits to create return_flag.

    This is a temporary bool declaration followed by an assignment of
    false.
    """
    declaration = declare_temp('bool', 'return_flag')
    initialization = assign_x('return_flag', const_bool(False))
    return declaration + initialization
171
def declare_return_value():
    """Create the statement that lower_jumps.cpp uses to declare the
    temporary variable return_value.  Assume that return_value is a float.

    Unlike the flag helpers, this emits only the declaration.  No
    initializing assignment is generated.
    """
    return declare_temp('float', 'return_value')
178
def declare_break_flag():
    """Return the statements lower_jumps.cpp emits to create break_flag.

    This is a temporary bool declaration followed by an assignment of
    false.
    """
    declaration = declare_temp('bool', 'break_flag')
    initialization = assign_x('break_flag', const_bool(False))
    return declaration + initialization
185
def lowered_return_simple(value=None):
    """Create the statements that lower_jumps.cpp lowers a return
    statement to, in situations where it does not need to clear the
    execute flag.

    If value is given, it is first stored into return_value, and then
    return_flag is set to true.  value is compared against None, as
    return_() does, rather than tested for truthiness, so any explicitly
    supplied expression is honoured.
    """
    if value is not None:
        result = assign_x('return_value', value)
    else:
        result = []
    return result + assign_x('return_flag', const_bool(True))
196
def lowered_return(value=None):
    """Return the statements a return is lowered to when execute_flag
    must also be cleared.

    This is lowered_return_simple(value) followed by an assignment of
    false to execute_flag.
    """
    flag_updates = lowered_return_simple(value)
    clear_execute = assign_x('execute_flag', const_bool(False))
    return flag_updates + clear_execute
204
def lowered_continue():
    """Return the statement a continue is lowered to.

    The statement is an assignment of false to execute_flag.
    """
    clear_execute = assign_x('execute_flag', const_bool(False))
    return clear_execute
210
def lowered_break_simple():
    """Return the statement a break is lowered to when execute_flag can
    be left alone.

    The statement is an assignment of true to break_flag.
    """
    set_break = assign_x('break_flag', const_bool(True))
    return set_break
217
def lowered_break():
    """Return the statements a break is lowered to when execute_flag must
    also be cleared.

    This is lowered_break_simple() followed by an assignment of false to
    execute_flag.
    """
    set_break = lowered_break_simple()
    clear_execute = assign_x('execute_flag', const_bool(False))
    return set_break + clear_execute
224
def if_execute_flag(statements):
    """Guard statements so they only run while execute_flag is true.

    Returns a one-statement list holding an if with an empty else branch.
    """
    check_sexp(statements)
    condition = ['var_ref', 'execute_flag']
    no_else = []
    return [['if', condition, statements, no_else]]
231
def if_return_flag(then_statements, else_statements):
    """Return a one-statement list branching on return_flag.

    then_statements run when return_flag is true, and else_statements run
    otherwise.
    """
    for branch in (then_statements, else_statements):
        check_sexp(branch)
    condition = ['var_ref', 'return_flag']
    return [['if', condition, then_statements, else_statements]]
238
def if_not_return_flag(statements):
    """Guard statements so they only run while return_flag is false.

    The statements go in the else branch of an if on return_flag, which
    has an empty then branch.
    """
    check_sexp(statements)
    condition = ['var_ref', 'return_flag']
    no_then = []
    return [['if', condition, no_then, statements]]
245
def final_return():
    """Return the (return (var_ref return_value)) statement list that
    lower_jumps.cpp appends to a function when lowering returns.
    """
    returned = ['var_ref', 'return_value']
    return [['return', returned]]
251
def final_break():
    """Return the conditional break that lower_jumps.cpp places at the end
    of a function when lowering breaks.

    The statement is an if on break_flag whose then branch is a break and
    whose else branch is empty.
    """
    condition = ['var_ref', 'break_flag']
    return [['if', condition, break_(), []]]
257
def bash_quote(*args):
    """Join args into one string that bash splits back into the same words.

    Words made only of letters, digits and the characters @%_-+=:,./ are
    emitted as-is.  An empty word becomes ''.  Any other word is wrapped
    in single quotes, with each embedded single quote spelled '"'"'.
    """
    safe_punctuation = '@%_-+=:,./'

    def quote_word(word):
        if not word:
            return "''"
        if all(ch.isalpha() or ch.isdigit() or ch in safe_punctuation
               for ch in word):
            return word
        return "'" + word.replace("'", "'\"'\"'") + "'"

    return ' '.join(map(quote_word, args))
272
def create_test_case(input_sexp, expected_sexp, test_name,
                     pull_out_jumps=False, lower_sub_return=False,
                     lower_main_return=False, lower_continue=False):
    """Build one do_lower_jumps test case as a 4-tuple.

    Returns (test_name, optimization, input_str, expected_output):

    - input_str and expected_output are the stringified forms of
      input_sexp and expected_sexp, with their declarations sorted.
    - optimization is the do_lower_jumps(...) call string, with each flag
      rendered as an integer.
    """
    for sexp in (input_sexp, expected_sexp):
        check_sexp(sexp)
    # XXX: ideally expected_output would not be stringified here.
    input_str, expected_output = (
        sexp_to_string(sort_decls(sexp))
        for sexp in (input_sexp, expected_sexp))
    flags = (pull_out_jumps, lower_sub_return, lower_main_return,
             lower_continue)
    optimization = 'do_lower_jumps({0:d}, {1:d}, {2:d}, {3:d})'.format(*flags)
    return (test_name, optimization, input_str, expected_output)
289
def test_lower_returns_main():
    """Check that do_lower_jumps honours lower_main_return for main().

    With the flag set, a return nested inside a complex if is replaced by
    flag assignments.  With the flag clear, the function must come back
    unchanged.
    """
    original = make_test_case('main', 'void', complex_if('', return_()))
    lowered = make_test_case('main', 'void',
                             declare_execute_flag() +
                             declare_return_flag() +
                             complex_if('', lowered_return()))
    cases = (
        (True, lowered, 'lower_returns_main_true'),
        (False, original, 'lower_returns_main_false'),
    )
    for flag, expected, name in cases:
        yield create_test_case(original, expected, name,
                               lower_main_return=flag)
308
def test_lower_returns_sub():
    """Check that do_lower_jumps honours lower_sub_return for subroutines.

    With the flag set, a return nested inside a complex if is replaced by
    flag assignments.  With the flag clear, the subroutine must come back
    unchanged.
    """
    original = make_test_case('sub', 'void', complex_if('', return_()))
    lowered = make_test_case('sub', 'void',
                             declare_execute_flag() +
                             declare_return_flag() +
                             complex_if('', lowered_return()))
    cases = (
        (True, lowered, 'lower_returns_sub_true'),
        (False, original, 'lower_returns_sub_false'),
    )
    for flag, expected, name in cases:
        yield create_test_case(original, expected, name,
                               lower_sub_return=flag)
327
def test_lower_returns_1():
    """Check that a void return that ends main() is simply removed."""
    before = make_test_case('main', 'void',
                            assign_x('a', const_float(1)) + return_())
    after = make_test_case('main', 'void', assign_x('a', const_float(1)))
    yield create_test_case(before, after, 'lower_returns_1',
                           lower_main_return=True)
339
def test_lower_returns_2():
    """Test that lowering is not performed on a non-void return at the end
    of a subroutine.

    The expected output is therefore identical to the input.
    """
    input_sexp = make_test_case('sub', 'float', (
            assign_x('a', const_float(1)) +
            return_(const_float(1))
            ))
    yield create_test_case(
        input_sexp, input_sexp, 'lower_returns_2', lower_sub_return=True)
350
def test_lower_returns_3():
    """Check lowering when one return sits inside a complex if and another
    ends the function.

    The trailing return must be lowered as well.  Once the final return
    is inserted, it is no longer the last statement of the function.
    """
    before = make_test_case('sub', 'float',
                            complex_if('', return_(const_float(1))) +
                            return_(const_float(2)))
    after = make_test_case('sub', 'float',
                           declare_execute_flag() +
                           declare_return_value() +
                           declare_return_flag() +
                           complex_if('', lowered_return(const_float(1))) +
                           if_execute_flag(lowered_return(const_float(2))) +
                           final_return())
    yield create_test_case(before, after, 'lower_returns_3',
                           lower_sub_return=True)
372
def test_lower_returns_4():
    """Check lowering of returns that appear in both arms of an if."""
    before = make_test_case('sub', 'float',
                            simple_if('a', return_(const_float(1)),
                                      return_(const_float(2))))
    after = make_test_case('sub', 'float',
                           declare_execute_flag() +
                           declare_return_value() +
                           declare_return_flag() +
                           simple_if('a', lowered_return(const_float(1)),
                                     lowered_return(const_float(2))) +
                           final_return())
    yield create_test_case(before, after, 'lower_returns_4',
                           lower_sub_return=True)
391
def test_lower_unified_returns():
    """Check that returns ending both arms of an if are pulled out and
    lowered when pull_out_jumps is set.

    The pair of returns inside the 'c' if is replaced by a single lowered
    return placed after that if.  Because this happens in the same pass
    as the lowering of the other returns, no extra temporaries may appear
    in the expected output.
    """
    before = make_test_case('main', 'void',
                            complex_if('a', return_()) +
                            simple_if('b', simple_if('c', return_(),
                                                     return_())))
    after = make_test_case('main', 'void',
                           declare_execute_flag() +
                           declare_return_flag() +
                           complex_if('a', lowered_return()) +
                           if_execute_flag(
                               simple_if('b', simple_if('c', [], []) +
                                         lowered_return())))
    yield create_test_case(before, after, 'lower_unified_returns',
                           lower_main_return=True, pull_out_jumps=True)
414
def test_lower_pulled_out_jump():
    """If one branch of an if ends in a jump, and control cannot
    fall out the bottom of the other branch, and pull_out_jumps is
    True, then the jump is lifted outside the if.

    Verify that this lowering occurs during the same pass as the
    lowering of other jumps by checking that extra temporary
    variables aren't generated.
    """
    input_sexp = make_test_case('main', 'void', (
            complex_if('a', return_()) +
            loop(simple_if('b', simple_if('c', break_(), continue_()),
                           return_())) +
            assign_x('d', const_float(1))
            ))
    # Note: optimization produces two other effects: the break
    # gets lifted out of the if statements, and the code after the
    # loop gets guarded so that it only executes if the return
    # flag is clear.
    expected_sexp = make_test_case('main', 'void', (
            declare_execute_flag() +
            declare_return_flag() +
            complex_if('a', lowered_return()) +
            if_execute_flag(
                loop(simple_if('b', simple_if('c', [], continue_()),
                               lowered_return_simple()) +
                     break_()) +
                if_return_flag(assign_x('return_flag', const_bool(1)) +
                               assign_x('execute_flag', const_bool(0)),
                               assign_x('d', const_float(1))))
            ))
    yield create_test_case(
        input_sexp, expected_sexp, 'lower_pulled_out_jump',
        lower_main_return=True, pull_out_jumps=True)
450
451
def test_remove_continue_at_end_of_loop():
    """Check that a continue that ends a loop body is dropped as redundant.

    Every lowering flag is left at its default value.
    """
    before = make_test_case('main', 'void',
                            loop(assign_x('a', const_float(1)) + continue_()))
    after = make_test_case('main', 'void',
                           loop(assign_x('a', const_float(1))))
    yield create_test_case(before, after, 'remove_continue_at_end_of_loop')
464
def test_lower_return_void_at_end_of_loop():
    """Check a void return that ends a loop body, both unlowered and lowered.

    Without lower_main_return the function is expected back unchanged.
    With lower_main_return set, the return becomes a return_flag set plus
    a break, and the code after the loop is guarded by return_flag.
    """
    before = make_test_case('main', 'void',
                            loop(assign_x('a', const_float(1)) + return_()) +
                            assign_x('b', const_float(2)))
    after = make_test_case('main', 'void',
                           declare_execute_flag() +
                           declare_return_flag() +
                           loop(assign_x('a', const_float(1)) +
                                lowered_return_simple() +
                                break_()) +
                           if_return_flag(
                               assign_x('return_flag', const_bool(1)) +
                               assign_x('execute_flag', const_bool(0)),
                               assign_x('b', const_float(2))))
    variants = (
        ('return_void_at_end_of_loop_lower_nothing', before, False),
        ('return_void_at_end_of_loop_lower_return', after, True),
    )
    for name, expected, lower_main in variants:
        yield create_test_case(before, expected, name,
                               lower_main_return=lower_main)
487
488
def test_lower_return_non_void_at_end_of_loop():
    """Test that a non-void return at the end of a loop is properly lowered.

    The return inside the loop becomes a store to return_value, a set of
    return_flag and a break.  The code after the loop is then guarded by
    return_flag, and a single final return of return_value is appended.
    """
    input_sexp = make_test_case('sub', 'float', (
            loop(assign_x('a', const_float(1)) +
                 return_(const_float(2))) +
            assign_x('b', const_float(3)) +
            return_(const_float(4))
            ))
    expected_sexp = make_test_case('sub', 'float', (
            declare_execute_flag() +
            declare_return_value() +
            declare_return_flag() +
            loop(assign_x('a', const_float(1)) +
                 lowered_return_simple(const_float(2)) +
                 break_()) +
            # Build the self-assignment from a real sexp list, like every
            # other expression in this file, rather than from the
            # pre-stringified text '(var_ref return_value)'.
            if_return_flag(assign_x('return_value',
                                    ['var_ref', 'return_value']) +
                           assign_x('return_flag', const_bool(1)) +
                           assign_x('execute_flag', const_bool(0)),
                           assign_x('b', const_float(3)) +
                           lowered_return(const_float(4))) +
            final_return()
            ))
    yield create_test_case(
        input_sexp, input_sexp, 'return_non_void_at_end_of_loop_lower_nothing')
    yield create_test_case(
        input_sexp, expected_sexp,
        'return_non_void_at_end_of_loop_lower_return', lower_sub_return=True)
516
517
# Every test generator defined above, collected into one list.  Each
# generator yields the 4-tuples returned by create_test_case.
CASES = [
    test_lower_pulled_out_jump,
    test_lower_return_non_void_at_end_of_loop,
    test_lower_return_void_at_end_of_loop,
    test_lower_returns_1,
    test_lower_returns_2,
    test_lower_returns_3,
    test_lower_returns_4,
    test_lower_returns_main,
    test_lower_returns_sub,
    test_lower_unified_returns,
    test_remove_continue_at_end_of_loop,
]
526