# coding=utf-8
#
# Copyright © 2011, 2018 Intel Corporation
#
# Permission is hereby granted, free of charge, to any person obtaining a
# copy of this software and associated documentation files (the "Software"),
# to deal in the Software without restriction, including without limitation
# the rights to use, copy, modify, merge, publish, distribute, sublicense,
# and/or sell copies of the Software, and to permit persons to whom the
# Software is furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice (including the next
# paragraph) shall be included in all copies or substantial portions of the
# Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
# DEALINGS IN THE SOFTWARE.

from sexps import *


def make_test_case(f_name, ret_type, body):
    """Create a simple optimization test case consisting of a single
    function with the given name, return type, and body.

    Global declarations are automatically created for any undeclared
    variables that are referenced by the function. All undeclared
    variables are assumed to be floats: a name that appears in a
    (var_ref <name>) gets (declare (in) float <name>), and a name that
    is the target of an (assign ...) gets (declare (out) float <name>).
    Names introduced by a (declare ...) statement earlier in an
    enclosing statement list are not given a global declaration.

    Returns a list of top-level sexps: the generated declarations,
    followed by
    (function <f_name> (signature <ret_type> (parameters) <body>)).
    """
    check_sexp(body)
    # Maps variable name -> generated global declaration. Keyed by name
    # so each variable is declared only once; if a name is both read and
    # written, whichever reference the walk visits last wins.
    declarations = {}

    def make_declarations(sexp, already_declared=()):
        if isinstance(sexp, list):
            if len(sexp) == 2 and sexp[0] == 'var_ref':
                # A read of a variable: declare it as an input.
                if sexp[1] not in already_declared:
                    declarations[sexp[1]] = [
                        'declare', ['in'], 'float', sexp[1]]
            elif len(sexp) == 4 and sexp[0] == 'assign':
                # (assign <mask> (var_ref <name>) <value>): the target is
                # declared as an output; the value may read other names.
                assert sexp[2][0] == 'var_ref'
                if sexp[2][1] not in already_declared:
                    declarations[sexp[2][1]] = [
                        'declare', ['out'], 'float', sexp[2][1]]
                make_declarations(sexp[3], already_declared)
            else:
                # Any other list (e.g. a statement list): a declare seen
                # here hides the name from the elements after it and from
                # everything nested inside them. Copy the set so those
                # names do not leak into the caller's scope.
                already_declared = set(already_declared)
                for s in sexp:
                    if isinstance(s, list) and len(s) >= 4 and \
                       s[0] == 'declare':
                        already_declared.add(s[3])
                    else:
                        make_declarations(s, already_declared)
    make_declarations(body)
    return list(declarations.values()) + \
        [['function', f_name, ['signature', ret_type, ['parameters'], body]]]


# The following functions can be used to build expressions.

def const_float(value):
    """Create an expression representing the given floating point value.

    The value is rendered with six digits after the decimal point, e.g.
    const_float(1) gives (constant float (1.000000)).
    """
    return ['constant', 'float', ['{0:.6f}'.format(value)]]


def const_bool(value):
    """Create an expression representing the given boolean value.

    If value is not a boolean, it is converted to a boolean. So, for
    instance, const_bool(1) is equivalent to const_bool(True). The
    constant is rendered as 1 (true) or 0 (false).
    """
    return ['constant', 'bool', ['{0}'.format(1 if value else 0)]]


def gt_zero(var_name):
    """Create the expression var_name > 0.

    It is encoded with the operands swapped, as
    (expression bool < (constant float (0.000000)) (var_ref var_name)),
    i.e. 0.0 < var_name.
    """
    return ['expression', 'bool', '<', const_float(0), ['var_ref', var_name]]


# The following functions can be used to build complex control flow
# statements. All of these functions return statement lists (even
# those which only create a single statement), so that statements can
# be sequenced together using the '+' operator.

def return_(value=None):
    """Create a return statement.

    Returns [['return']] for a void return, or [['return', value]] when
    a return value is supplied.
    """
    if value is not None:
        return [['return', value]]
    else:
        return [['return']]


def break_():
    """Create a break statement.

    The statement itself is the bare symbol 'break' (not a nested
    list), so the returned one-statement list is ['break'].
    """
    return ['break']


def continue_():
    """Create a continue statement.

    Like break_(), the statement is a bare symbol, so the returned
    one-statement list is ['continue'].
    """
    return ['continue']


def simple_if(var_name, then_statements, else_statements=None):
    """Create a statement of the form

        if (var_name > 0.0) {
           <then_statements>
        } else {
           <else_statements>
        }

    else_statements may be omitted, in which case the else branch is an
    empty statement list.
    """
    if else_statements is None:
        else_statements = []
    check_sexp(then_statements)
    check_sexp(else_statements)
    return [['if', gt_zero(var_name), then_statements, else_statements]]


def loop(statements):
    """Create a loop containing the given statements as its loop
    body.
    """
    check_sexp(statements)
    return [['loop', statements]]


def declare_temp(var_type, var_name):
    """Create a declaration of the form

        (declare (temporary) <var_type> <var_name>)
    """
    return [['declare', ['temporary'], var_type, var_name]]


def assign_x(var_name, value):
    """Create a statement that assigns <value> to the variable
    <var_name>. The assignment uses the mask (x).
    """
    check_sexp(value)
    return [['assign', ['x'], ['var_ref', var_name], value]]


def complex_if(var_prefix, statements):
    """Create a statement of the form

        if (<var_prefix>a > 0.0) {
           if (<var_prefix>b > 0.0) {
              <statements>
           }
        }

    This is useful in testing jump lowering, because if <statements>
    ends in a jump, lower_jumps.cpp won't try to combine this
    construct with the code that follows it, as it might do for a
    simple if.

    All variables used in the if statement are prefixed with
    var_prefix. This can be used to ensure uniqueness.
    """
    check_sexp(statements)
    return simple_if(var_prefix + 'a', simple_if(var_prefix + 'b', statements))


def declare_execute_flag():
    """Create the statements that lower_jumps.cpp uses to declare and
    initialize the temporary boolean execute_flag (initialized to true).
    """
    return declare_temp('bool', 'execute_flag') + \
        assign_x('execute_flag', const_bool(True))


def declare_return_flag():
    """Create the statements that lower_jumps.cpp uses to declare and
    initialize the temporary boolean return_flag (initialized to false).
    """
    return declare_temp('bool', 'return_flag') + \
        assign_x('return_flag', const_bool(False))


def declare_return_value():
    """Create the statement that lower_jumps.cpp uses to declare the
    temporary variable return_value. Assume that return_value is a
    float. No initializing assignment is generated.
    """
    return declare_temp('float', 'return_value')


def declare_break_flag():
    """Create the statements that lower_jumps.cpp uses to declare and
    initialize the temporary boolean break_flag (initialized to false).
    """
    return declare_temp('bool', 'break_flag') + \
        assign_x('break_flag', const_bool(False))


def lowered_return_simple(value=None):
    """Create the statements that lower_jumps.cpp lowers a return
    statement to, in situations where it does not need to clear the
    execute flag.

    If value is given it is first stored in return_value; return_flag
    is then set to true.
    """
    # Test against None, as return_() does, rather than relying on the
    # truthiness of the sexp, so "no value" is spelled one way.
    if value is not None:
        result = assign_x('return_value', value)
    else:
        result = []
    return result + assign_x('return_flag', const_bool(True))


def lowered_return(value=None):
    """Create the statements that lower_jumps.cpp lowers a return
    statement to, in situations where it needs to clear the execute
    flag.
    """
    return lowered_return_simple(value) + \
        assign_x('execute_flag', const_bool(False))


def lowered_continue():
    """Create the statement that lower_jumps.cpp lowers a continue
    statement to (clearing execute_flag).
    """
    return assign_x('execute_flag', const_bool(False))


def lowered_break_simple():
    """Create the statement that lower_jumps.cpp lowers a break
    statement to, in situations where it does not need to clear the
    execute flag.
    """
    return assign_x('break_flag', const_bool(True))


def lowered_break():
    """Create the statement that lower_jumps.cpp lowers a break
    statement to, in situations where it needs to clear the execute
    flag.
    """
    return lowered_break_simple() + assign_x('execute_flag', const_bool(False))


def if_execute_flag(statements):
    """Wrap statements in an if test so that they will only execute if
    execute_flag is True.
    """
    check_sexp(statements)
    return [['if', ['var_ref', 'execute_flag'], statements, []]]


def if_return_flag(then_statements, else_statements):
    """Wrap statements in an if test with return_flag as the condition.
    """
    check_sexp(then_statements)
    check_sexp(else_statements)
    return [['if', ['var_ref', 'return_flag'], then_statements, else_statements]]


def if_not_return_flag(statements):
    """Wrap statements in an if test so that they will only execute if
    return_flag is False.
    """
    check_sexp(statements)
    return [['if', ['var_ref', 'return_flag'], [], statements]]


def final_return():
    """Create the return statement that lower_jumps.cpp places at the
    end of a function when lowering returns.
    """
    return [['return', ['var_ref', 'return_value']]]


def final_break():
    """Create the conditional break statement that lower_jumps.cpp
    emits when lowering breaks:

        if (break_flag) break;

    Because it contains a break, it is only meaningful inside a loop
    body.
    """
    return [['if', ['var_ref', 'break_flag'], break_(), []]]


def bash_quote(*args):
    """Quote the arguments appropriately so that bash will understand
    each argument as a single word.

    Words made only of letters, digits and the characters @%_-+=:,./
    are passed through unchanged; the empty string becomes ''; any
    other word is wrapped in single quotes, with each embedded single
    quote written as '"'"'. The quoted words are joined with spaces.
    """
    def quote_word(word):
        for c in word:
            if not (c.isalpha() or c.isdigit() or c in '@%_-+=:,./'):
                break
        else:
            # Every character is safe (or there are no characters).
            if not word:
                return "''"
            return word
        return "'{0}'".format(word.replace("'", "'\"'\"'"))
    return ' '.join(quote_word(word) for word in args)


def create_test_case(input_sexp, expected_sexp, test_name,
                     pull_out_jumps=False, lower_sub_return=False,
                     lower_main_return=False, lower_continue=False):
    """Create a test case that verifies that do_lower_jumps transforms
    the given code in the expected way.

    Returns a tuple (test_name, optimization, input_str,
    expected_output). optimization is the string
    'do_lower_jumps(<pull_out_jumps>, <lower_sub_return>,
    <lower_main_return>, <lower_continue>)' with each flag rendered as
    an integer (0 or 1 for bools). Both sexps are passed through
    sort_decls and then converted with sexp_to_string.
    """
    check_sexp(input_sexp)
    check_sexp(expected_sexp)
    input_str = sexp_to_string(sort_decls(input_sexp))
    expected_output = sexp_to_string(sort_decls(expected_sexp))  # XXX: don't stringify this
    optimization = (
        'do_lower_jumps({0:d}, {1:d}, {2:d}, {3:d})'.format(
            pull_out_jumps, lower_sub_return, lower_main_return,
            lower_continue))

    return (test_name, optimization, input_str, expected_output)


def test_lower_returns_main():
    """Test that do_lower_jumps respects the lower_main_return flag in deciding
    whether to lower returns in the main function.
    """
    input_sexp = make_test_case('main', 'void', (
        complex_if('', return_())
    ))
    expected_sexp = make_test_case('main', 'void', (
        declare_execute_flag() +
        declare_return_flag() +
        complex_if('', lowered_return())
    ))
    yield create_test_case(
        input_sexp, expected_sexp, 'lower_returns_main_true',
        lower_main_return=True)
    yield create_test_case(
        input_sexp, input_sexp, 'lower_returns_main_false',
        lower_main_return=False)


def test_lower_returns_sub():
    """Test that do_lower_jumps respects the lower_sub_return flag in deciding
    whether to lower returns in subroutines.
    """
    input_sexp = make_test_case('sub', 'void', (
        complex_if('', return_())
    ))
    expected_sexp = make_test_case('sub', 'void', (
        declare_execute_flag() +
        declare_return_flag() +
        complex_if('', lowered_return())
    ))
    yield create_test_case(
        input_sexp, expected_sexp, 'lower_returns_sub_true',
        lower_sub_return=True)
    yield create_test_case(
        input_sexp, input_sexp, 'lower_returns_sub_false',
        lower_sub_return=False)


def test_lower_returns_1():
    """Test that a void return at the end of a function is eliminated."""
    input_sexp = make_test_case('main', 'void', (
        assign_x('a', const_float(1)) +
        return_()
    ))
    expected_sexp = make_test_case('main', 'void', (
        assign_x('a', const_float(1))
    ))
    yield create_test_case(
        input_sexp, expected_sexp, 'lower_returns_1', lower_main_return=True)


def test_lower_returns_2():
    """Test that lowering is not performed on a non-void return at the end of
    a subroutine.
    """
    input_sexp = make_test_case('sub', 'float', (
        assign_x('a', const_float(1)) +
        return_(const_float(1))
    ))
    yield create_test_case(
        input_sexp, input_sexp, 'lower_returns_2', lower_sub_return=True)


def test_lower_returns_3():
    """Test lowering of returns when there is one nested inside a complex
    structure of ifs, and one at the end of a function.

    In this case, the latter return needs to be lowered because it will not be
    at the end of the function once the final return is inserted.
    """
    input_sexp = make_test_case('sub', 'float', (
        complex_if('', return_(const_float(1))) +
        return_(const_float(2))
    ))
    expected_sexp = make_test_case('sub', 'float', (
        declare_execute_flag() +
        declare_return_value() +
        declare_return_flag() +
        complex_if('', lowered_return(const_float(1))) +
        if_execute_flag(lowered_return(const_float(2))) +
        final_return()
    ))
    yield create_test_case(
        input_sexp, expected_sexp, 'lower_returns_3', lower_sub_return=True)


def test_lower_returns_4():
    """Test that returns are properly lowered when they occur in both branches
    of an if-statement.
    """
    input_sexp = make_test_case('sub', 'float', (
        simple_if('a', return_(const_float(1)),
                  return_(const_float(2)))
    ))
    expected_sexp = make_test_case('sub', 'float', (
        declare_execute_flag() +
        declare_return_value() +
        declare_return_flag() +
        simple_if('a', lowered_return(const_float(1)),
                  lowered_return(const_float(2))) +
        final_return()
    ))
    yield create_test_case(
        input_sexp, expected_sexp, 'lower_returns_4', lower_sub_return=True)


def test_lower_unified_returns():
    """If both branches of an if statement end in a return, and pull_out_jumps
    is True, then those returns should be lifted outside the if and then
    properly lowered.

    Verify that this lowering occurs during the same pass as the lowering of
    other returns by checking that extra temporary variables aren't generated.
    """
    input_sexp = make_test_case('main', 'void', (
        complex_if('a', return_()) +
        simple_if('b', simple_if('c', return_(), return_()))
    ))
    expected_sexp = make_test_case('main', 'void', (
        declare_execute_flag() +
        declare_return_flag() +
        complex_if('a', lowered_return()) +
        if_execute_flag(simple_if('b', (simple_if('c', [], []) +
                                        lowered_return())))
    ))
    yield create_test_case(
        input_sexp, expected_sexp, 'lower_unified_returns',
        lower_main_return=True, pull_out_jumps=True)


def test_lower_pulled_out_jump():
    """If one branch of an if ends in a jump, and control cannot
    fall out the bottom of the other branch, and pull_out_jumps is
    True, then the jump is lifted outside the if.

    Verify that this lowering occurs during the same pass as the
    lowering of other jumps by checking that extra temporary
    variables aren't generated.
    """
    input_sexp = make_test_case('main', 'void', (
        complex_if('a', return_()) +
        loop(simple_if('b', simple_if('c', break_(), continue_()),
                       return_())) +
        assign_x('d', const_float(1))
    ))
    # Note: optimization produces two other effects: the break
    # gets lifted out of the if statements, and the code after the
    # loop gets guarded so that it only executes if the return
    # flag is clear.
    expected_sexp = make_test_case('main', 'void', (
        declare_execute_flag() +
        declare_return_flag() +
        complex_if('a', lowered_return()) +
        if_execute_flag(
            loop(simple_if('b', simple_if('c', [], continue_()),
                           lowered_return_simple()) +
                 break_()) +

            if_return_flag(assign_x('return_flag', const_bool(1)) +
                           assign_x('execute_flag', const_bool(0)),
                           assign_x('d', const_float(1))))
    ))
    yield create_test_case(
        input_sexp, expected_sexp, 'lower_pulled_out_jump',
        lower_main_return=True, pull_out_jumps=True)


def test_remove_continue_at_end_of_loop():
    """Test that a redundant continue-statement at the end of a loop is
    removed.
    """
    input_sexp = make_test_case('main', 'void', (
        loop(assign_x('a', const_float(1)) +
             continue_())
    ))
    expected_sexp = make_test_case('main', 'void', (
        loop(assign_x('a', const_float(1)))
    ))
    yield create_test_case(input_sexp, expected_sexp, 'remove_continue_at_end_of_loop')


def test_lower_return_void_at_end_of_loop():
    """Test that a return of void at the end of a loop is properly lowered."""
    input_sexp = make_test_case('main', 'void', (
        loop(assign_x('a', const_float(1)) +
             return_()) +
        assign_x('b', const_float(2))
    ))
    expected_sexp = make_test_case('main', 'void', (
        declare_execute_flag() +
        declare_return_flag() +
        loop(assign_x('a', const_float(1)) +
             lowered_return_simple() +
             break_()) +
        if_return_flag(assign_x('return_flag', const_bool(1)) +
                       assign_x('execute_flag', const_bool(0)),
                       assign_x('b', const_float(2)))
    ))
    yield create_test_case(
        input_sexp, input_sexp, 'return_void_at_end_of_loop_lower_nothing')
    yield create_test_case(
        input_sexp, expected_sexp, 'return_void_at_end_of_loop_lower_return',
        lower_main_return=True)


def test_lower_return_non_void_at_end_of_loop():
    """Test that a non-void return at the end of a loop is properly lowered."""
    input_sexp = make_test_case('sub', 'float', (
        loop(assign_x('a', const_float(1)) +
             return_(const_float(2))) +
        assign_x('b', const_float(3)) +
        return_(const_float(4))
    ))
    expected_sexp = make_test_case('sub', 'float', (
        declare_execute_flag() +
        declare_return_value() +
        declare_return_flag() +
        loop(assign_x('a', const_float(1)) +
             lowered_return_simple(const_float(2)) +
             break_()) +
        # Build the var_ref as a real sexp (not a pre-rendered string)
        # so it is checked and walked like every other expression.
        if_return_flag(assign_x('return_value', ['var_ref', 'return_value']) +
                       assign_x('return_flag', const_bool(1)) +
                       assign_x('execute_flag', const_bool(0)),
                       assign_x('b', const_float(3)) +
                       lowered_return(const_float(4))) +
        final_return()
    ))
    yield create_test_case(
        input_sexp, input_sexp, 'return_non_void_at_end_of_loop_lower_nothing')
    yield create_test_case(
        input_sexp, expected_sexp,
        'return_non_void_at_end_of_loop_lower_return', lower_sub_return=True)


# Every test generator above; each yields the tuples built by
# create_test_case().
CASES = [
    test_lower_pulled_out_jump,
    test_lower_return_non_void_at_end_of_loop,
    test_lower_return_void_at_end_of_loop,
    test_lower_returns_1, test_lower_returns_2, test_lower_returns_3,
    test_lower_returns_4, test_lower_returns_main, test_lower_returns_sub,
    test_lower_unified_returns, test_remove_continue_at_end_of_loop,
]