1# mypy: allow-untyped-defs 2import contextlib 3import dataclasses 4import functools 5import itertools 6import math 7import re 8import sys 9import warnings 10from copy import copy, deepcopy 11from enum import Enum 12from typing import cast, Dict, List, Optional, Sequence, Set, Tuple, Union 13 14import sympy 15 16import torch 17import torch.fx 18from torch._inductor import dependencies 19from torch._prims_common import is_float_dtype, is_integer_dtype 20from torch.utils._sympy.functions import CeilDiv, FloorDiv, ModularIndexing 21from torch.utils._sympy.symbol import free_symbol_is_type, symbol_is_type, SymT 22 23from ..._dynamo.utils import counters 24from .. import codecache, config, cpp_builder, cpu_vec_isa, ir, metrics 25from ..loop_body import LoopBody 26from ..scheduler import ( 27 BaseSchedulerNode, 28 BaseScheduling, 29 ForeachKernelSchedulerNode, 30 FusedSchedulerNode, 31 Scheduler, 32 SchedulerNode, 33) 34from ..utils import ( 35 cache_on_self, 36 get_bounds_index_expr, 37 get_fused_kernel_name, 38 has_free_symbols, 39 is_welford_reduction, 40 parallel_num_threads, 41 Placeholder, 42 sympy_index_symbol, 43 sympy_index_symbol_with_prefix, 44 sympy_product, 45 sympy_subs, 46) 47from ..virtualized import NullKernelHandler, ops, OpsValue, V 48from .common import ( 49 BackendFeature, 50 BracesBuffer, 51 CppWrapperKernelArgs, 52 CSE, 53 CSEVariable, 54 DataTypePropagation, 55 DeferredLine, 56 DTYPE_TO_COMPUTATION_DTYPE, 57 IndentedBuffer, 58 Kernel, 59 KernelArgs, 60 OpOverrides, 61 OptimizationContext, 62) 63from .cpp_utils import ( 64 _get_dtype_from_loopbodies, 65 _get_loop_body, 66 cexpr, 67 cexpr_index, 68 codegen_rand, 69 CppCSEVariable, 70 DTYPE_TO_CPP, 71 INDEX_TYPE, 72 LocalBufferContext, 73 promote_args, 74 unify_mask_base_type, 75 value_to_cpp, 76) 77 78 79_IS_WINDOWS = sys.platform == "win32" 80 81 82def get_export_declaration(): 83 return "__declspec(dllexport)" if _IS_WINDOWS else "" 84 85 86schedule_log = torch._logging.getArtifactLogger(__name__, "schedule") 87 88NATIVE_OMP_RTYPES = {"+", "*", "^", "||", "min", "max"} 89RTYPE_TO_CPP = { 90 "sum": "+", 91 "prod": "*", 92 "xor_sum": "^", 93 "min": "min", 94 "max": "max", 95 "argmin": "argmin", 96 "argmax": "argmax", 97 "any": "||", 98 "welford_reduce": "welford", 99 "welford_combine": "welford", 100} 101VECTORIZABLE_RTYPES = { 102 "max", 103 "min", 104 "sum", 105 "prod", 106 "xor_sum", 107 "welford_reduce", 108 "welford_combine", 109 "argmin", 110 "argmax", 111 "any", 112} 113 114PYTHON_TO_CPP = { 115 "Tensor": "at::Tensor", 116 "int": "long", 117 "float": "double", 118 "bool": "bool", 119 "str": "std::string", 120 "ScalarType": "c10::ScalarType", 121 "MemoryFormat": "at::MemoryFormat", 122 "Layout": "at::Layout", 123 "Device": "at::Device", 124 "number": "at::Scalar", 125} 126 127CONTAINER_PYTHON_TO_CPP = { 128 "List": "std::vector", 129 "Optional": "std::optional", 130} 131 132DTYPE_LOWP_FP = [ 133 torch.bfloat16, 134 torch.float16, 135] 136 137VECTORIZABLE_DTYPES: List[torch.dtype] = [ 138 torch.float64, 139 torch.float, 140 torch.bfloat16, 141 torch.float16, 142 torch.bool, 143 torch.uint8, 144 torch.int8, 145 torch.int32, 146 torch.int64, 147] 148 149MASKED_VECTORIZABLE_DTYPES: List[torch.dtype] = [ 150 torch.float, 151 torch.bfloat16, 152 torch.float16, 153 torch.uint8, 154 torch.int8, 155] 156 157 158def reduction_init(reduction_type, dtype): 159 if dtype in DTYPE_LOWP_FP: 160 # Since load promotes all half-precision inputs to float, the initial 161 # constant for reduction must be promoted as well 162 dtype = 
torch.float32 163 if reduction_type in ("xor_sum", "sum", "any"): 164 return 0 165 if reduction_type == "prod": 166 return 1 167 if reduction_type in ("max", "argmax", "min", "argmin"): 168 cdtype = DTYPE_TO_CPP[dtype] 169 min_var = ( 170 f"-std::numeric_limits<{cdtype}>::infinity()" 171 if is_float_dtype(dtype) 172 else f"std::numeric_limits<{cdtype}>::min()" 173 ) 174 max_var = ( 175 f"std::numeric_limits<{cdtype}>::infinity()" 176 if is_float_dtype(dtype) 177 else f"std::numeric_limits<{cdtype}>::max()" 178 ) 179 init_var = min_var if reduction_type in ("max", "argmax") else max_var 180 return ( 181 init_var 182 if reduction_type in ("max", "min") 183 else f"IndexValue<{cdtype}>{{0, {init_var}}}" 184 ) 185 if is_welford_reduction(reduction_type): 186 return f"Welford<{DTYPE_TO_CPP[dtype]}>()" 187 raise AssertionError(reduction_type) 188 189 190def reduction_acc_type(reduction_type, dtype): 191 scalar_type = DTYPE_TO_CPP[DTYPE_TO_COMPUTATION_DTYPE[dtype]] 192 if is_welford_reduction(reduction_type): 193 return f"Welford<{scalar_type}>" 194 if reduction_type in {"argmin", "argmax"}: 195 return f"IndexValue<{scalar_type}>" 196 return scalar_type 197 198 199def reduction_combine( 200 reduction_type, 201 var, 202 next_value, 203 index: Optional[sympy.Symbol] = None, 204 src_dtype=None, 205): 206 is_bool = src_dtype == torch.bool 207 if reduction_type == "sum": 208 conjunction = "|" if is_bool else "+" 209 return f"{var} {conjunction} {next_value}" 210 if reduction_type == "prod": 211 return f"{var} * {next_value}" 212 if reduction_type == "xor_sum": 213 return f"{var} ^ {next_value}" 214 if reduction_type == "any": 215 return f"{var} || {next_value}" 216 if reduction_type in ("min", "max"): 217 return f"{reduction_type}_propagate_nan({var}, {next_value})" 218 if reduction_type == "welford_reduce": 219 return f"welford_combine({var}, {next_value})" 220 if reduction_type == "welford_combine": 221 if isinstance(next_value, tuple): 222 mean, m2, weight = next_value 223 else: 224 mean, m2, weight = reduction_project(reduction_type, next_value) 225 return f"welford_combine({var}, {{{mean}, {m2}, {weight}}})" 226 if reduction_type in ("argmin", "argmax"): 227 if index is not None: 228 return f"{reduction_type}_combine({var}, {next_value}, {index})" 229 else: 230 return f"{reduction_type}_combine({var}, {next_value})" 231 raise AssertionError(reduction_type) 232 233 234def reduction_project(reduction_type, acc): 235 if is_welford_reduction(reduction_type): 236 return f"{acc}.mean", f"{acc}.m2", f"{acc}.weight" 237 elif reduction_type in {"argmin", "argmax"}: 238 return f"{acc}.index" 239 return acc 240 241 242@functools.lru_cache 243def stride_at(index: sympy.Expr, var: sympy.Symbol): 244 if not index.has(var): 245 # see test_torchinductor_dynamic_shapes.py::test_full_boolean_dynamic_shapes_cpu 246 # which has tmp0 = ops.index_expr(s0 >= 1024, torch.bool) and fails below calculation. 247 # in this case, there is no dependencies between index and var. 248 return sympy.Integer(0) 249 replacement = {var: var + 1} 250 new_index = sympy_subs(index, replacement) # type: ignore[arg-type] 251 return sympy.simplify(new_index - index) 252 253 254@functools.lru_cache 255def simplify_index_in_vec_range(index: sympy.Expr, var: sympy.Expr, vec_length: int): 256 """ 257 Simplifies the index expression within the range of a vectorized loop. 258 Given a vectorized loop variable `var` in the range of a loop with `vec_length`, 259 this function transforms the `index` into an equivalent form. 
It handles 260 simplifications for cases where `var` can be expressed as `vec_length * a + b`, 261 where `b` ranges from 0 to `vec_length - 1`. The function reduces occurrences 262 of `FloorDiv` and `ModularIndexing` in the `index` with best-effort optimizations. 263 264 NOTE: 265 The simplified index expression is intended for analysis purposes only, not 266 for code generation. It replaces `FloorDiv` and `ModularIndexing` with free variables 267 which are not dependent on the loop variable `var` in the vectorized range. Check 268 https://github.com/pytorch/pytorch/pull/117221#discussion_r1449746217 for more details. 269 270 Examples: 271 1. If `var` is `x3` and `vec_length` is 16, and `x3 = 16*a + b`, then 272 `FloorDiv(x3, div)` or `ModularIndexing(x3, div, mod)` becomes a free variable 273 when `div` is divisible by 16. 274 2. `ModularIndexing(x3, 1, mod)` can be simplified to `x3 + c` where `c` is a free 275 variable when `mod` is divisible by 16. 276 """ 277 278 div_freevar_id = 0 279 mod_freevar_id = 0 280 281 def visit_indexing_div(divisor): 282 nonlocal div_freevar_id 283 result = FloorDiv(var, divisor) 284 if sympy.gcd(divisor, vec_length) == vec_length: 285 result = sympy.Symbol(f"{var}_div_c{div_freevar_id}") 286 div_freevar_id += 1 287 return result 288 289 def visit_modular_indexing(divisor, modulus): 290 nonlocal mod_freevar_id 291 result = ModularIndexing(var, divisor, modulus) 292 if sympy.gcd(divisor, vec_length) == vec_length: 293 result = sympy.Symbol(f"{var}_mod_c{mod_freevar_id}") 294 mod_freevar_id += 1 295 elif divisor == 1 and sympy.gcd(modulus, vec_length) == vec_length: 296 result = var + sympy.Symbol(f"{var}_mod_c{mod_freevar_id}") 297 mod_freevar_id += 1 298 return result 299 300 original_index = index 301 302 div = sympy.Wild("divisor", integer=True) 303 if index.has(FloorDiv): 304 index = index.replace(FloorDiv(var, div), visit_indexing_div) 305 306 mod = sympy.Wild("modulus", integer=True) 307 if index.has(ModularIndexing): 308 index = index.replace(ModularIndexing(var, div, mod), visit_modular_indexing) 309 310 index = sympy.simplify(index) 311 if index != original_index: 312 return simplify_index_in_vec_range(index, var, vec_length) 313 314 return index 315 316 317@functools.lru_cache 318def stride_at_vec_range( 319 index: sympy.Expr, var: sympy.Symbol, vec_length: Optional[int] = None 320): 321 if vec_length: 322 index = simplify_index_in_vec_range(index, var, vec_length) 323 return stride_at(index, var) 324 325 326class OuterLoopFusedSchedulerNode(FusedSchedulerNode): 327 @classmethod 328 def fuse( # type: ignore[override] 329 cls, node1: BaseSchedulerNode, node2: BaseSchedulerNode, outer_loop_fusion_depth 330 ): 331 assert node1.scheduler is node2.scheduler 332 assert all( 333 type(node) 334 in ( 335 OuterLoopFusedSchedulerNode, 336 SchedulerNode, 337 FusedSchedulerNode, 338 ) 339 for node in (node1, node2) 340 ) 341 if any(type(node) is OuterLoopFusedSchedulerNode for node in (node1, node2)): 342 return cls( 343 node1.scheduler, 344 ( 345 list(node1.get_outer_nodes()) 346 if type(node1) is OuterLoopFusedSchedulerNode 347 else [ 348 node1, 349 ] 350 ) 351 + ( 352 list(node2.get_outer_nodes()) 353 if type(node2) is OuterLoopFusedSchedulerNode 354 else [ 355 node2, 356 ] 357 ), 358 outer_loop_fusion_depth, 359 ) 360 else: 361 return cls(node1.scheduler, [node1, node2], outer_loop_fusion_depth) # type: ignore[list-item] 362 363 def __init__( 364 self, 365 scheduler: "Scheduler", 366 outer_fused_nodes: List[Union[FusedSchedulerNode, SchedulerNode]], 367 
        outer_loop_fusion_depth,
    ):
        self.outer_fused_nodes: List[
            Union[FusedSchedulerNode, SchedulerNode]
        ] = outer_fused_nodes
        self.outer_loop_fusion_depth = outer_loop_fusion_depth
        flatten_snodes = []
        for _node in self.outer_fused_nodes:
            assert isinstance(_node, (SchedulerNode, FusedSchedulerNode))
            flatten_snodes.extend(list(_node.get_nodes()))
        super().__init__(scheduler, flatten_snodes)  # type: ignore[arg-type]

    def get_outer_nodes(self):
        return self.outer_fused_nodes

    def check_outer_fusion_loop_level_attr(
        self, cpp_kernel_proxy_list, outer_loop_fusion_depth
    ):
        # This function ensures that the same tiling split is applied at each loop level
        # within the outer loop fusion depth.
        # In the fusion stage, we only examine nodes with the same vars and reduce.
        # However, even for nodes with the same vars and reduce, the loops may still have
        # different tile splits.
        # For example (test_expr_vec_non_contiguous in test_cpu_repro.py):
        # * buf0 is tiled along the 2nd loop level, buf1 along the 3rd loop level.
        # If the check fails, we fall back to standard loop codegen.
        def _inner(
            left_loop_level: LoopLevel,
            right_loop_level: LoopLevel,
            loop_fusion_depth: int,
        ) -> bool:
            # Check whether the loop-level attributes match
            outer_loops_attr_compare_list = [
                "var",
                "size",
                "offset",
                "steps",
            ]
            if not (
                all(
                    getattr(left_loop_level, attr_compare)
                    == getattr(right_loop_level, attr_compare)
                    for attr_compare in outer_loops_attr_compare_list
                )
            ):
                return False

            assert loop_fusion_depth >= 1
            if (loop_fusion_depth := loop_fusion_depth - 1) > 0:
                # If the next loop level is expected to undergo outer loop fusion,
                # there should be no kernel present at the current loop level.
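                # A minimal, hypothetical sketch of what the attribute comparison above
                # accepts and rejects (made-up LoopLevel fields, not real test data):
                #   left : var=x0, size=s0, offset=0, steps=1
                #   right: var=x0, size=s0, offset=0, steps=1   -> all attrs equal, recurse inward
                # A mismatch such as steps=1 vs steps=16 at any compared level returns False,
                # and the caller falls back to standard loop codegen.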
416 assert ( 417 left_loop_level.kernel is None and right_loop_level.kernel is None 418 ) 419 # Check next loop level attr 420 if any( 421 # Assume no main/tail loop split at any outer loop fusion depth 422 # Given no clear performance benefit for this complex case 423 len(loop_level.inner) != 1 424 for loop_level in [left_loop_level, right_loop_level] 425 ) or not _inner( 426 left_loop_level.inner[0], 427 right_loop_level.inner[0], 428 loop_fusion_depth, 429 ): 430 return False 431 432 return True 433 434 for idx in range(len(cpp_kernel_proxy_list) - 1): 435 left_loop_nest = cpp_kernel_proxy_list[idx].loop_nest 436 right_loop_nest = cpp_kernel_proxy_list[idx + 1].loop_nest 437 if any( 438 # Assume no main/tail loop split at any outer loop fusion depth 439 len(loop_nest.root) != 1 440 for loop_nest in [left_loop_nest, right_loop_nest] 441 ) or not _inner( 442 left_loop_nest.root[0], right_loop_nest.root[0], outer_loop_fusion_depth 443 ): 444 return False 445 446 return True 447 448 def merge_outer_fusion_kernels( 449 self, 450 cpp_kernel_proxy_list, 451 ): 452 loop_nest_list: List[LoopNestWithSplit] = [ 453 kernel.loop_nest for kernel in cpp_kernel_proxy_list 454 ] 455 kernel_group = cpp_kernel_proxy_list[0].kernel_group 456 457 def _merge_outer_fusion_loop_levels( 458 loop_level_nested_list: List[List["LoopLevel"]], 459 outer_loop_fusion_depth, 460 ): 461 assert outer_loop_fusion_depth >= 1 462 # Assume no main/tail loop split at any outer loop fusion depth 463 assert all( 464 len(loop_level_list) == 1 for loop_level_list in loop_level_nested_list 465 ) 466 if (outer_loop_fusion_depth := outer_loop_fusion_depth - 1) >= 1: 467 # Further merge the next loop level 468 next_loop_level_nested_list = [ 469 loop_level_list[0].inner 470 for loop_level_list in loop_level_nested_list 471 ] 472 _merge_outer_fusion_loop_levels( 473 next_loop_level_nested_list, 474 outer_loop_fusion_depth, 475 ) 476 else: 477 outer_loop_fused_kernel = OuterLoopFusedKernel(kernel_group) 478 loop_level_of_first_kernel = loop_level_nested_list[0][0] 479 for kernel_idx in range(len(loop_level_nested_list)): 480 outer_loop_fused_kernel.inner.append( 481 deepcopy(loop_level_nested_list[kernel_idx][0]), 482 ) 483 loop_level_of_first_kernel.inner = [] 484 loop_level_of_first_kernel.kernel = outer_loop_fused_kernel 485 486 # Merge the List[LoopNestWithSplit] from cpp_kernel_proxy_list 487 # into cpp_kernel_proxy_list[0].loop_nest 488 _merge_outer_fusion_loop_levels( 489 [_loop_nest.root for _loop_nest in loop_nest_list], # type: ignore[misc] 490 self.outer_loop_fusion_depth, 491 ) 492 return cpp_kernel_proxy_list[0] 493 494 495class RecordOptimizationContext: 496 def __init__(self, func_name: str = ""): 497 self.func_name = func_name 498 self.current_node: Optional[torch.fx.Node] = None 499 self.opt_ctx: Optional[OptimizationContext] = None 500 501 def __enter__(self): 502 assert V.interpreter 503 assert V.interpreter.current_node 504 505 self.current_node = V.interpreter.current_node 506 assert self.current_node is not None 507 if OptimizationContext.key in self.current_node.meta: 508 self.opt_ctx = self.current_node.meta[OptimizationContext.key] 509 else: 510 self.opt_ctx = OptimizationContext() 511 assert self.opt_ctx is not None 512 self.opt_ctx.ops_name = self.func_name 513 return self 514 515 def __exit__(self, exc_type, exc_val, exc_tb): 516 assert self.current_node 517 assert self.opt_ctx 518 self.current_node.meta[OptimizationContext.key] = self.opt_ctx 519 520 def get_opt_ctx(self): 521 return self.opt_ctx 522 523 
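    # Illustrative usage sketch (assumed, not taken from the original source): inside an
    # op handler where V.interpreter.current_node is set, the context manager attaches an
    # OptimizationContext to the current FX node's meta on exit, e.g.:
    #
    #     with RecordOptimizationContext("to_dtype") as rc:
    #         opt_ctx = rc.get_opt_ctx()
    #         opt_ctx.dtype = torch.float32  # hypothetical field update
    #         fx_node = rc.get_fx_node()
    #
    # On __exit__, the (possibly updated) opt_ctx is written back to fx_node.meta.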
    def get_fx_node(self):
        assert self.current_node
        return self.current_node


class CppOverrides(OpOverrides):
    """Map element-wise ops to C++"""

    @staticmethod
    def add(a, b):
        return f"decltype({a})({a} + {b})"

    @staticmethod
    def sub(a, b):
        return f"decltype({a})({a} - {b})"

    @staticmethod
    def mul(a, b):
        return f"decltype({a})({a} * {b})"

    @staticmethod
    def to_dtype(x, dtype, src_dtype=None, use_compute_types=True):
        assert isinstance(x, CppCSEVariable)
        if src_dtype is None:
            src_dtype = x.dtype
        expr = V.kernel.get_to_dtype_expr(x, dtype, src_dtype)
        csevar = V.kernel.cse.generate(V.kernel.compute, expr)
        csevar.update_on_args("to_dtype", (x, dtype), {"src_dtype": src_dtype})
        if dtype in [torch.bfloat16, torch.float16] and src_dtype == torch.float:
            """
            https://github.com/pytorch/pytorch/issues/115260
            For FusedSchedulerNode[node1, node2], node2 loads what node1 stores, and the buffer is
            in a low-precision floating-point data type. When the output of node1 also serves as an
            output of the kernel, the result would differ from the case where node1's output is not
            a kernel output (where we don't need to insert `to_dtype` for legalization). To address
            this, when storing the low-precision node1 output, we also add the inverse dtype
            conversion back to the high-precision data type to the cse cache.

            Example (pseudo code):
                node1_output = ...
                node1_output_lowp = to_dtype(node1_output, dtype=torch.bfloat16)
                store(buf, node1_output_lowp)
                node2_input_lowp = load(buf)
                node2_input = to_dtype(node2_input_lowp, dtype=torch.float)

            Without cse cache trick:
                node1_output = ...
                node1_output_lowp = to_dtype(node1_output, dtype=torch.bfloat16)
                store(buf, node1_output_lowp)
                node2_input_lowp = node1_output_lowp # hit store cache
                node2_input = to_dtype(node2_input_lowp, dtype=torch.float)

            With cse cache trick:
                node1_output = ...
                node1_output_lowp = to_dtype(node1_output, dtype=torch.bfloat16)
                # also add `to_dtype(node1_output_lowp, dtype=torch.float)` -> `node1_output` to cse cache
                store(buf, node1_output_lowp)
                node2_input_lowp = node1_output_lowp # hit store cache
                node2_input = node1_output # hit cse cache
            """
            V.kernel.cache_dtype_convert(x, src_dtype, csevar, dtype)
        return csevar

    @staticmethod
    def to_dtype_bitcast(x, dtype, src_dtype):
        assert dtype in DTYPE_TO_CPP, f"{dtype} missing from {__name__}.DTYPE_TO_CPP"
        if src_dtype in (torch.float16, torch.bfloat16):
            # c10::bit_cast requires the source and target to have the same bitwidth.
            # Because the input tensor's dtype could be promoted, e.g. from float16 to
            # float, we have to cast the tensor to its original source dtype before
            # invoking bit_cast. We also need to convert the bit-casted tensor
            # back to float to make sure we keep using higher-precision values
            # for the rest of the computation.
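            # Illustrative example (a sketch of the expression built below, not emitted
            # verbatim): for src_dtype == torch.bfloat16, an input expression `tmp0`
            # lowers to roughly
            #   c10::convert<float>(c10::bit_cast<TARGET_CPP_TYPE>(c10::convert<SRC_CPP_TYPE>(tmp0)))
            # where SRC_CPP_TYPE and TARGET_CPP_TYPE come from DTYPE_TO_CPP.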
596 cast_x = f"c10::convert<{DTYPE_TO_CPP[src_dtype]}>({x})" 597 cast_x = f"c10::bit_cast<{DTYPE_TO_CPP[dtype]}>({cast_x})" 598 return f"c10::convert<{DTYPE_TO_CPP[torch.float32]}>({cast_x})" 599 else: 600 return f"c10::bit_cast<{DTYPE_TO_CPP[dtype]}>({x})" 601 602 @staticmethod 603 def abs(x): 604 return f"std::abs({x})" 605 606 @staticmethod 607 def sin(x): 608 return f"std::sin({x})" 609 610 @staticmethod 611 def cos(x): 612 return f"std::cos({x})" 613 614 @staticmethod 615 def neg(x): 616 return f"decltype({x})(-{x})" 617 618 @staticmethod 619 def exp(x): 620 # return f"Sleef_expf_u10({x})" 621 return f"std::exp({x})" 622 623 @staticmethod 624 def exp2(x): 625 return f"std::exp2({x})" 626 627 @staticmethod 628 def expm1(x): 629 return f"std::expm1({x})" 630 631 @staticmethod 632 def erf(x): 633 return f"std::erf({x})" 634 635 @staticmethod 636 def erfc(x): 637 return f"std::erfc({x})" 638 639 @staticmethod 640 def erfinv(x): 641 return f"calc_erfinv({x})" 642 643 @staticmethod 644 def sqrt(x): 645 return f"std::sqrt({x})" 646 647 @staticmethod 648 def rsqrt(x): 649 return f"1 / std::sqrt({x})" 650 651 @staticmethod 652 def log1p(x): 653 bug = config.cpp.inject_log1p_bug_TESTING_ONLY 654 if bug == "accuracy": 655 return f"{x} + decltype({x})(1)" 656 elif bug is None: 657 return f"std::log1p({x})" 658 else: 659 raise AssertionError( 660 f"unrecognized config cpp.inject_log1p_bug_TESTING_ONLY = {bug!r}" 661 ) 662 663 @staticmethod 664 def tan(x): 665 return f"std::tan({x})" 666 667 @staticmethod 668 def tanh(x): 669 return f"std::tanh({x})" 670 671 @staticmethod 672 def signbit(x): 673 """ 674 On windows std::signbit only support float type. 675 Ref: https://learn.microsoft.com/en-us/cpp/c-runtime-library/reference/signbit?view=msvc-170 676 """ 677 return ( 678 f"std::signbit(static_cast<float>({x}))" 679 if _IS_WINDOWS 680 else f"std::signbit({x})" 681 ) 682 683 @staticmethod 684 def pow(a, b): 685 return f"std::pow({a}, {b})" 686 687 @staticmethod 688 def log(x): 689 return f"std::log({x})" 690 691 @staticmethod 692 def round(x): 693 return f"std::nearbyint({x})" 694 695 @staticmethod 696 def floor(x): 697 return f"std::floor({x})" 698 699 @staticmethod 700 def floordiv(a, b): 701 # a and b are integer type 702 quot = f"{a} / {b}" 703 rem = f"{a} % {b}" 704 return f"(({a} < 0) != ({b} < 0) ? ({rem} != 0 ? 
{quot} - 1 : {quot}) : {quot})" 705 706 @staticmethod 707 def ceil(x): 708 return f"std::ceil({x})" 709 710 @staticmethod 711 def trunc(x): 712 return f"std::trunc({x})" 713 714 @staticmethod 715 def truncdiv(a, b): 716 # a and b are integer type 717 return f"{a} / {b}" 718 719 @staticmethod 720 def fmod(a, b): 721 return f"std::fmod({a}, {b})" 722 723 @staticmethod 724 def isinf(x): 725 return f"std::isinf({x})" 726 727 @staticmethod 728 def isnan(x): 729 return f"std::isnan({x})" 730 731 @staticmethod 732 def lgamma(x): 733 return f"std::lgamma({x})" 734 735 @staticmethod 736 def acos(x): 737 return f"std::acos({x})" 738 739 @staticmethod 740 def acosh(x): 741 return f"std::acosh({x})" 742 743 @staticmethod 744 def cosh(x): 745 return f"std::cosh({x})" 746 747 @staticmethod 748 def sinh(x): 749 return f"std::sinh({x})" 750 751 @staticmethod 752 def asin(x): 753 return f"std::asin({x})" 754 755 @staticmethod 756 def asinh(x): 757 return f"std::asinh({x})" 758 759 @staticmethod 760 def atan2(x, y): 761 return f"std::atan2({x}, {y})" 762 763 @staticmethod 764 def atan(x): 765 return f"std::atan({x})" 766 767 @staticmethod 768 def atanh(x): 769 return f"std::atanh({x})" 770 771 @staticmethod 772 def copysign(x, y): 773 return f"std::copysign({x}, {y})" 774 775 @staticmethod 776 def frexp(x): 777 cache_keys = f"frexp({x})[0]", f"frexp({x})[1]" 778 if all(cache_key in V.kernel.cse.cache for cache_key in cache_keys): 779 return tuple(V.kernel.cse.cache[cache_key] for cache_key in cache_keys) 780 781 code = BracesBuffer() 782 exponent = V.kernel.cse.newvar() 783 mantissa = V.kernel.cse.newvar() 784 code.writeline(f"int32_t {exponent};") 785 code.writeline(f"auto {mantissa} = std::frexp({x}, &{exponent});") 786 V.kernel.compute.splice(code) 787 cse_vars = (mantissa, exponent) 788 for cache_key, cse_var in zip(cache_keys, cse_vars): 789 V.kernel.cse.cache[cache_key] = cse_var 790 return mantissa, exponent 791 792 @staticmethod 793 def hypot(x, y): 794 return f"std::hypot({x}, {y})" 795 796 @staticmethod 797 def log10(x): 798 return f"std::log10({x})" 799 800 @staticmethod 801 def log2(x): 802 return f"std::log2({x})" 803 804 @staticmethod 805 def nextafter(x, y): 806 return f"std::nextafter({x}, {y})" 807 808 @staticmethod 809 def relu(x): 810 bug = config.cpp.inject_relu_bug_TESTING_ONLY 811 if bug == "compile_error": 812 return "compile error!" 813 elif bug == "runtime_error": 814 return f"{x}; throw 1" 815 elif bug == "accuracy": 816 return f"{x} + decltype({x})(1)" 817 elif bug is None: 818 return f"std::max({x}, decltype({x})(0))" 819 else: 820 raise AssertionError( 821 f"unrecognized config cpp.inject_relu_bug_TESTING_ONLY = {bug!r}" 822 ) 823 824 @staticmethod 825 def minimum(a, b): 826 return f"min_propagate_nan({a}, {b})" 827 828 @staticmethod 829 def maximum(a, b): 830 return f"max_propagate_nan({a}, {b})" 831 832 @staticmethod 833 def where(a, b, c): 834 return f"{a} ? 
{b} : {c}" 835 836 @staticmethod 837 def mod(a, b): 838 return f"mod({a}, {b})" 839 840 @staticmethod 841 def constant(val, dtype): 842 if dtype in DTYPE_LOWP_FP: 843 # Since load promotes all half-precision inputs to float, constants 844 # must be promoted as well 845 dtype = torch.float32 846 return value_to_cpp(val, DTYPE_TO_CPP[dtype]) 847 848 @staticmethod 849 def index_expr(expr, dtype): 850 idx_str = cexpr(V.kernel.rename_indexing(expr)) 851 var = V.kernel.cse.generate( 852 V.kernel.compute, idx_str, bounds=get_bounds_index_expr(expr) 853 ) 854 return ops.to_dtype(var, dtype) 855 856 @staticmethod 857 def masked(mask, body, other): 858 code = BracesBuffer() 859 860 # Write masked operation into a lambda 861 body_var = V.kernel.cse.newvar() 862 code.writeline(f"auto {body_var} = [&]") 863 with V.kernel.swap_buffers(code), code.indent(): 864 result = body() 865 code.writeline(f"return {result};") 866 code.writeline(";") 867 V.kernel.compute.splice(code) 868 869 # Use the lambda's return type as the type of other 870 other_code = value_to_cpp(other, f"decltype({body_var}())") 871 return f"{mask} ? {body_var}() : {other_code}" 872 873 @staticmethod 874 def logical_and(a, b): 875 return f"{a} && {b}" 876 877 @staticmethod 878 def logical_not(a): 879 return f"!{a}" 880 881 @staticmethod 882 def logical_or(a, b): 883 return f"{a} || {b}" 884 885 @staticmethod 886 def logical_xor(a, b): 887 return f"{a} != {b}" 888 889 @staticmethod 890 def bitwise_and(a, b): 891 return f"decltype({a})({a} & {b})" 892 893 @staticmethod 894 def bitwise_not(a): 895 return f"decltype({a})(~{a})" 896 897 @staticmethod 898 def bitwise_or(a, b): 899 return f"decltype({a})({a} | {b})" 900 901 @staticmethod 902 def bitwise_xor(a, b): 903 return f"decltype({a})({a} ^ {b})" 904 905 @staticmethod 906 def bitwise_left_shift(a, b): 907 return f"decltype({a})({a} << {b})" 908 909 @staticmethod 910 def bitwise_right_shift(a, b): 911 return f"decltype({a})({a} >> {b})" 912 913 @staticmethod 914 def rand(seed: sympy.Expr, offset: sympy.Expr): 915 return f"normalized_rand_cpu({seed}, {offset})" 916 917 @staticmethod 918 def randn(seed: sympy.Expr, offset: sympy.Expr): 919 return f"randn_cpu({seed}, {offset})" 920 921 @staticmethod 922 def randint64(seed: sympy.Expr, offset: sympy.Expr, low, high): 923 return f"randint64_cpu({seed}, {offset}, {low}, {high})" 924 925 @staticmethod 926 def sigmoid(x): 927 return f"decltype({x})(1) / (decltype({x})(1) + std::exp(-{x}))" 928 929 @staticmethod 930 def sign(x): 931 code = BracesBuffer() 932 scalar_zero = f"decltype({x})(0)" 933 scalar_one = f"decltype({x})(1)" 934 code.writeline("[&]()") 935 with code.indent(): 936 code.writeline(f"auto left = {x} > 0 ? {scalar_one} : {scalar_zero};") 937 code.writeline(f"auto right = {x} < 0 ? {scalar_one} : {scalar_zero};") 938 code.writeline("return left - right;") 939 code.writeline("()") 940 return code 941 942 943CppOverrides._initialize_pointwise_overrides("cpp") 944 945 946class CppVecOverrides(CppOverrides): 947 """Map element-wise ops to aten vectorization C++""" 948 949 def __new__(cls, *args, **kargs): 950 self = super().__new__(cls) 951 952 def wrap(func): 953 # `CppVecKernel` generates both scalar ops and vector ops according to 954 # whether the inputs are scalars or vectors while all ops in `CppVecOverrides` 955 # (except for some ops explained below) assume the inputs are vectors. 
We wrap the ops in 956 # `CppVecOverrides` to broadcast scalar inputs to vectors if needed or fallback to 957 # `CppOverrides` when all inputs are scalars. 958 # 959 # Notes on ops handled separately in their own functions: 960 # `ops.masked`: 961 # needs recursive handling of masked body. 962 # `ops.index_expr`: 963 # needs to further analyze the dependency of the index expression on 964 # the tiling itervar. 965 def wrapper(*args, **kwargs): 966 scalars = [ 967 arg 968 for arg in args 969 if isinstance(arg, (int, sympy.Expr)) 970 or (isinstance(arg, CppCSEVariable) and not arg.is_vec) 971 ] 972 vectors = [ 973 arg 974 for arg in args 975 if isinstance(arg, CppCSEVariable) and arg.is_vec 976 ] 977 new_args = list(args) 978 if scalars and vectors: 979 new_args = [] 980 for arg in args: 981 if isinstance(arg, (int, sympy.Expr)): 982 if isinstance(arg, sympy.Expr) and not arg.is_number: 983 arg = ops.index_expr(arg, torch.int64) 984 else: 985 arg = ops.constant(arg, torch.int64) 986 arg = arg.value if isinstance(arg, OpsValue) else arg 987 new_args.append(arg) 988 989 # DType Promotion 990 if vectors: 991 # We have saw several data type mismatch issues related with index_expr in 992 # the lowering phase of torch.int8. torch.int32, torch.int64. 993 # 1. int32 and int64 in test_torchinductor.py::test_max_pool2d_with_indices_backward3_cpu 994 # 2. int8 and int32 in test_torchinductor.py::test_max_pool2d5_cpu 995 # 3. int32 and fp32 in test_torchinductor_dynamic_shapes.py::test_avg_pool2d8_dynamic_shapes_cpu 996 if len(new_args) == 2: 997 new_args = promote_args(new_args) 998 elif func == CppVecOverrides.where: 999 new_args[1:] = promote_args(new_args[1:]) 1000 1001 # Broadcast scalar args to vector 1002 if scalars and vectors: 1003 assert isinstance(V.kernel, CppVecKernel) 1004 new_args = [ 1005 V.kernel.broadcast(new_arg) 1006 if ( 1007 isinstance(new_arg, CppCSEVariable) 1008 and not new_arg.is_vec 1009 and func 1010 not in [ 1011 CppVecOverrides.rand, 1012 CppVecOverrides.randn, 1013 CppVecOverrides.randint64, 1014 ] 1015 ) 1016 else new_arg 1017 for new_arg in new_args 1018 ] 1019 1020 if vectors: 1021 return func(*new_args, **kwargs) 1022 else: 1023 # fallback to scalar ops 1024 scalar_ops = super(CppVecOverrides, self) 1025 scalar_func = getattr( 1026 scalar_ops, func.__name__, scalar_ops.__getattr__(func.__name__) # type: ignore[attr-defined] 1027 ) 1028 assert scalar_func is not None 1029 return scalar_func(*args, **kwargs) 1030 1031 return wrapper 1032 1033 for name, method in vars(CppVecOverrides).items(): 1034 if getattr(method, "__class__", None) == staticmethod and name not in [ 1035 "masked", 1036 "index_expr", 1037 ]: 1038 setattr(self, name, wrap(method.__func__)) 1039 1040 return self 1041 1042 @staticmethod 1043 def add(a, b): 1044 return f"{a} + {b}" 1045 1046 @staticmethod 1047 def sub(a, b): 1048 return f"{a} - {b}" 1049 1050 @staticmethod 1051 def mul(a, b): 1052 return f"{a} * {b}" 1053 1054 @staticmethod 1055 def truediv(a, b): 1056 return f"{a} / {b}" 1057 1058 @staticmethod 1059 def abs(x): 1060 return f"{x}.abs()" 1061 1062 @staticmethod 1063 def sin(x): 1064 return f"{x}.sin()" 1065 1066 @staticmethod 1067 def cos(x): 1068 return f"{x}.cos()" 1069 1070 @staticmethod 1071 def exp(x): 1072 return f"{x}.exp()" 1073 1074 @staticmethod 1075 def exp2(x): 1076 return f"{x}.exp2()" 1077 1078 @staticmethod 1079 def expm1(x): 1080 # decompose for a better performance 1081 vec_one = f"decltype({x})(1)" 1082 return f"{x}.exp() - {vec_one}" 1083 1084 @staticmethod 1085 def 
erf(x): 1086 return f"{x}.erf()" 1087 1088 @staticmethod 1089 def erfc(x): 1090 return f"{x}.erfc()" 1091 1092 @staticmethod 1093 def erfinv(x): 1094 return f"{x}.erfinv()" 1095 1096 @staticmethod 1097 def sqrt(x): 1098 return f"{x}.sqrt()" 1099 1100 @staticmethod 1101 def eq(x, y): 1102 assert isinstance(V.kernel, CppVecKernel) 1103 assert isinstance(x, CppCSEVariable) 1104 assert x.dtype is not None 1105 return f"{V.kernel._get_mask_type(x.dtype)}({x} == {y})" 1106 1107 @staticmethod 1108 def ne(x, y): 1109 assert isinstance(V.kernel, CppVecKernel) 1110 assert isinstance(x, CppCSEVariable) 1111 if x.dtype == torch.bool: 1112 assert y.dtype == torch.bool 1113 x_cast, y_cast = unify_mask_base_type(V.kernel.compute, (x, y)) 1114 return f"{x_cast} != {y_cast}" 1115 else: 1116 assert x.dtype is not None 1117 return f"{V.kernel._get_mask_type(x.dtype)}({x} != {y})" 1118 1119 @staticmethod 1120 def lt(x, y): 1121 assert isinstance(V.kernel, CppVecKernel) 1122 assert isinstance(x, CppCSEVariable) 1123 assert x.dtype is not None 1124 return f"{V.kernel._get_mask_type(x.dtype)}({x} < {y})" 1125 1126 @staticmethod 1127 def gt(x, y): 1128 assert isinstance(V.kernel, CppVecKernel) 1129 assert isinstance(x, CppCSEVariable) 1130 assert x.dtype is not None 1131 return f"{V.kernel._get_mask_type(x.dtype)}({x} > {y})" 1132 1133 @staticmethod 1134 def le(x, y): 1135 assert isinstance(V.kernel, CppVecKernel) 1136 assert isinstance(x, CppCSEVariable) 1137 assert x.dtype is not None 1138 return f"{V.kernel._get_mask_type(x.dtype)}({x} <= {y})" 1139 1140 @staticmethod 1141 def ge(x, y): 1142 assert isinstance(V.kernel, CppVecKernel) 1143 assert isinstance(x, CppCSEVariable) 1144 assert x.dtype is not None 1145 return f"{V.kernel._get_mask_type(x.dtype)}({x} >= {y})" 1146 1147 @staticmethod 1148 def and_(x, y): 1149 return f"{x} & {y}" 1150 1151 @staticmethod 1152 def rsqrt(x): 1153 return f"{x}.rsqrt()" 1154 1155 @staticmethod 1156 def pow(a, b): 1157 return f"{a}.pow({b})" 1158 1159 @staticmethod 1160 def log(x): 1161 return f"{x}.log()" 1162 1163 @staticmethod 1164 def round(x): 1165 return f"{x}.round()" 1166 1167 @staticmethod 1168 def floor(x): 1169 return f"{x}.floor()" 1170 1171 @staticmethod 1172 def ceil(x): 1173 return f"{x}.ceil()" 1174 1175 @staticmethod 1176 def trunc(x): 1177 return f"{x}.trunc()" 1178 1179 @staticmethod 1180 def fmod(a, b): 1181 return f"{a}.fmod({b})" 1182 1183 @staticmethod 1184 def lgamma(x): 1185 return f"{x}.lgamma()" 1186 1187 @staticmethod 1188 def logical_and(a, b): 1189 return f"{a} & {b}" 1190 1191 @staticmethod 1192 def logical_not(a): 1193 return f"~{a}" 1194 1195 @staticmethod 1196 def logical_or(a, b): 1197 return f"{a} | {b}" 1198 1199 @staticmethod 1200 def logical_xor(a, b): 1201 return f"{a} ^ {b}" 1202 1203 @staticmethod 1204 def bitwise_and(a, b): 1205 return f"{a} & {b}" 1206 1207 @staticmethod 1208 def bitwise_not(a): 1209 return f"~{a}" 1210 1211 @staticmethod 1212 def bitwise_or(a, b): 1213 return f"{a} | {b}" 1214 1215 @staticmethod 1216 def bitwise_xor(a, b): 1217 return f"{a} ^ {b}" 1218 1219 @staticmethod 1220 def bitwise_left_shift(a, b): 1221 return f"{a} << {b}" 1222 1223 @staticmethod 1224 def bitwise_right_shift(a, b): 1225 return f"{a} >> {b}" 1226 1227 @staticmethod 1228 def load_seed(name, offset): 1229 assert isinstance(V.kernel, CppVecKernel) 1230 return f"{V.kernel.load(name, offset)}" 1231 1232 @staticmethod 1233 def rand(seed, offset): 1234 assert isinstance(V.kernel, CppVecKernel) 1235 code = BracesBuffer() 1236 rand_function = ( 1237 
f"result[offset_idx] = normalized_rand_cpu({seed}, offset[offset_idx]);" 1238 ) 1239 return codegen_rand(offset, code, rand_function) 1240 1241 @staticmethod 1242 def randn(seed, offset): 1243 assert isinstance(V.kernel, CppVecKernel) 1244 code = BracesBuffer() 1245 rand_function = f"result[offset_idx] = randn_cpu({seed}, offset[offset_idx]);" 1246 return codegen_rand(offset, code, rand_function) 1247 1248 @staticmethod 1249 def randint64(seed, offset, low, high): 1250 assert isinstance(V.kernel, CppVecKernel) 1251 code = BracesBuffer() 1252 rand_function = f"result[offset_idx] = randint64_cpu({seed}, offset[offset_idx], {low}, {high});" 1253 return codegen_rand(offset, code, rand_function, torch.int64) 1254 1255 @staticmethod 1256 def remainder(a, b): 1257 assert ( 1258 a.dtype == b.dtype 1259 ), "remainder vec implementation expect the same inputs' dtype." 1260 return f"{a} - ({CppVecOverrides.floordiv(a, b)}) * {b}" 1261 1262 @staticmethod 1263 def tan(a): 1264 return f"{a}.tan()" 1265 1266 @staticmethod 1267 def tanh(a): 1268 vec_one = f"decltype({a})(1)" 1269 vec_two = f"decltype({a})(2)" 1270 vec_minus_two = f"decltype({a})(-2)" 1271 return f"{vec_two} / ({vec_one} + ({vec_minus_two} * {a}).exp()) - {vec_one}" 1272 1273 @staticmethod 1274 def reciprocal(a): 1275 return f"{a}.reciprocal()" 1276 1277 @staticmethod 1278 def atan(x): 1279 return f"{x}.atan()" 1280 1281 @staticmethod 1282 def acos(x): 1283 return f"{x}.acos()" 1284 1285 @staticmethod 1286 def asin(x): 1287 return f"{x}.asin()" 1288 1289 @staticmethod 1290 def cosh(x): 1291 return f"{x}.cosh()" 1292 1293 @staticmethod 1294 def sinh(x): 1295 return f"{x}.sinh()" 1296 1297 @staticmethod 1298 def log10(x): 1299 return f"{x}.log10()" 1300 1301 @staticmethod 1302 def log2(x): 1303 return f"{x}.log2()" 1304 1305 @staticmethod 1306 def nextafter(x, y): 1307 return f"{x}.nextafter({y})" 1308 1309 @staticmethod 1310 def copysign(a, b): 1311 return f"{a}.copysign({b})" 1312 1313 @staticmethod 1314 def atan2(a, b): 1315 return f"{a}.atan2({b})" 1316 1317 @staticmethod 1318 def hypot(a, b): 1319 return f"{a}.hypot({b})" 1320 1321 @staticmethod 1322 def atanh(x): 1323 # For real x, atanh(x) = 1/2 * log((1+x)/(1-x)) 1324 vec_one = f"decltype({x})(1)" 1325 vec_one_half = f"decltype({x})(0.5)" 1326 return f"{vec_one_half} * (({vec_one} + {x})/({vec_one} - {x})).log()" 1327 1328 @staticmethod 1329 def asinh(x): 1330 # For real x, asinh(x) = log(x + sqrt(1 + x**2)) 1331 vec_one = f"decltype({x})(1)" 1332 return f"({x} + ({vec_one} + {x}*{x}).sqrt()).log()" 1333 1334 @staticmethod 1335 def acosh(x): 1336 return f"{x}.acosh()" 1337 1338 @staticmethod 1339 def relu(x): 1340 bug = config.cpp.inject_relu_bug_TESTING_ONLY 1341 if bug == "compile_error": 1342 return "compile error!" 1343 elif bug == "runtime_error": 1344 return f"{x}; throw 1" 1345 elif bug == "accuracy": 1346 return f"{x} + decltype({x})(1)" 1347 elif bug is None: 1348 return f"at::vec::clamp_min({x}, decltype({x})(0))" 1349 else: 1350 raise AssertionError( 1351 f"unrecognized config cpp.inject_relu_bug_TESTING_ONLY = {bug!r}" 1352 ) 1353 1354 # TODO: this seems to be dead 1355 @staticmethod 1356 def sigmoid(x): 1357 return f"decltype({x})(1)/(decltype({x})(1) + {x}.neg().exp())" 1358 1359 @staticmethod 1360 def neg(x): 1361 return f"{x}.neg()" 1362 1363 @staticmethod 1364 def floordiv(a, b): 1365 if is_float_dtype(a.dtype): 1366 assert ( 1367 a.dtype == b.dtype 1368 ), "div_floor_floating_vec implementation expect the same inputs' dtype." 
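            # Both branches implement floor division (rounding toward negative infinity).
            # A scalar sketch of the integer else-branch below, for illustration only:
            #   quot = a / b;                                        // C++ truncates toward zero: -7 / 2 == -3
            #   if (a % b != 0 && ((a < 0) != (b < 0))) quot -= 1;   // -> -4 == floor(-3.5)
            # The vectorized code expresses the same adjustment with blendv.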
1369 return f"div_floor_floating_vec({a}, {b})" 1370 else: 1371 assert all(is_integer_dtype(item.dtype) for item in [a, b]) 1372 # a and b are integer type 1373 _t = f"decltype({a})" 1374 if V.kernel._get_raw_num_vectors(b.dtype) < 1: 1375 # Doing blend to set the remaining bits of b to non-zero 1376 b = f"{_t}::blend<{(1 << V.kernel.tiling_factor) - 1}>({_t}(1), {b})" 1377 quot = f"{a} / {b}" 1378 has_rem = f"({a} % {b} != {_t}(0))" 1379 is_neg = f"(({a} < {_t}(0)) != ({b} < {_t}(0)))" 1380 return f"{_t}::blendv({quot}, {quot} - {_t}(1), {has_rem} & {is_neg})" 1381 1382 @staticmethod 1383 def truncdiv(a, b): 1384 # a and b are integer type 1385 if V.kernel._get_raw_num_vectors(b.dtype) < 1: 1386 # Doing blend to set the remaining bits of b to non-zero 1387 _t = f"decltype({b})" 1388 b = f"{_t}::blend<{(1 << V.kernel.tiling_factor) - 1}>({_t}(1), {b})" 1389 return f"{a} / {b}" 1390 1391 @staticmethod 1392 def minimum(a, b): 1393 if a.dtype == torch.bool: 1394 assert b.dtype == torch.bool 1395 a_cast, b_cast = unify_mask_base_type(V.kernel.compute, (a, b)) 1396 return f"{a_cast} & {b_cast}" 1397 else: 1398 return f"at::vec::minimum({a}, {b})" 1399 1400 @staticmethod 1401 def maximum(a, b): 1402 if a.dtype == torch.bool: 1403 assert b.dtype == torch.bool 1404 a_cast, b_cast = unify_mask_base_type(V.kernel.compute, (a, b)) 1405 return f"{a_cast} | {b_cast}" 1406 else: 1407 return f"at::vec::maximum({a}, {b})" 1408 1409 @staticmethod 1410 def square(a): 1411 return f"{a} * {a}" 1412 1413 @staticmethod 1414 def where(a, b, c): 1415 assert isinstance(V.kernel, CppVecKernel) 1416 if b.dtype == torch.bool: 1417 assert c.dtype == torch.bool 1418 blendv_a, blendv_b, blendv_c = unify_mask_base_type( 1419 V.kernel.compute, (a, b, c) 1420 ) 1421 return f"decltype({blendv_b})::blendv({blendv_c}, {blendv_b}, {blendv_a})" 1422 else: 1423 return f"decltype({b})::blendv({c}, {b}, {V.kernel._get_mask_cast(a, b.dtype)})" 1424 1425 @staticmethod 1426 def sign(x): 1427 code = BracesBuffer() 1428 vec_zero = f"decltype({x})(0)" 1429 vec_one = f"decltype({x})(1)" 1430 blendv_l = f"decltype({x})::blendv({vec_zero}, {vec_one}, {vec_zero} < {x})" 1431 blendv_r = f"decltype({x})::blendv({vec_zero}, {vec_one}, {x} < {vec_zero})" 1432 code.writeline("[&]()") 1433 with code.indent(): 1434 code.writeline(f"auto left = {blendv_l};") 1435 code.writeline(f"auto right = {blendv_r};") 1436 code.writeline("return left - right;") 1437 code.writeline("()") 1438 return code 1439 1440 @staticmethod 1441 def to_dtype(x, dtype, src_dtype=None, use_compute_dtypes=True): 1442 assert dtype in [ 1443 torch.bool, 1444 torch.float64, 1445 torch.float, 1446 torch.bfloat16, 1447 torch.float16, 1448 torch.uint8, 1449 torch.int8, 1450 torch.int32, 1451 torch.int64, 1452 ], f"{__name__} does not support {dtype}" 1453 assert isinstance(x, CppCSEVariable) 1454 src_dtype = x.dtype 1455 expr = V.kernel.get_to_dtype_expr(x, dtype, src_dtype) 1456 csevar = V.kernel.cse.generate(V.kernel.compute, expr) 1457 csevar.update_on_args("to_dtype", (x, dtype), {"src_dtype": src_dtype}) 1458 if dtype in [torch.bfloat16, torch.float16] and src_dtype == torch.float: 1459 V.kernel.cache_dtype_convert(x, src_dtype, csevar, dtype) 1460 return csevar 1461 1462 @staticmethod 1463 def log1p(x): 1464 bug = config.cpp.inject_log1p_bug_TESTING_ONLY 1465 if bug == "accuracy": 1466 return f"{x} + decltype({x})(1)" 1467 elif bug is None: 1468 return f"{x}.log1p()" 1469 else: 1470 raise AssertionError( 1471 f"unrecognized config cpp.inject_log1p_bug_TESTING_ONLY = {bug!r}" 
1472 ) 1473 1474 @staticmethod 1475 def masked(mask, body, other): 1476 assert isinstance(V.kernel, CppVecKernel) 1477 code = BracesBuffer() 1478 var = V.kernel.cse.newvar() 1479 with V.kernel.masked(mask) as new_mask: 1480 code.writeline(f"auto {var} = [&]") 1481 with V.kernel.swap_buffers(code), code.indent(): 1482 result = body() 1483 code.writeline(f"return {result};") 1484 code.writeline(";") 1485 V.kernel.compute.splice(code) 1486 1487 dtype = result.dtype 1488 body_code = f"{var}()" 1489 body_code_vec = ( 1490 body_code 1491 if result.is_vec 1492 else f"{V.kernel._get_vec_type(dtype)}({body_code})" 1493 ) 1494 other_code = value_to_cpp(other, DTYPE_TO_CPP[dtype]) 1495 # loading bool as VecMask<float, N> 1496 other_code_vec = ( 1497 f"{V.kernel._get_mask_type()}::from({other_code})" 1498 if dtype == torch.bool 1499 else f"{V.kernel._get_vec_type(dtype)}({other_code})" 1500 ) 1501 assert isinstance(new_mask, CppCSEVariable), new_mask 1502 if new_mask.is_vec: 1503 code = BracesBuffer() 1504 code.writeline("[&]") 1505 with V.kernel.swap_buffers(code), code.indent(): 1506 code.writeline(f"if ({new_mask}.all_zero())") 1507 with code.indent(): 1508 code.writeline(f"return {other_code_vec};") 1509 code.writeline("else") 1510 with code.indent(): 1511 # Create cse variable to reuse kernel.overrides.where 1512 body_vec_var = V.kernel.cse.generate( 1513 V.kernel.compute, 1514 body_code_vec, 1515 ) 1516 other_vec_var = V.kernel.cse.generate( 1517 V.kernel.compute, 1518 other_code_vec, 1519 ) 1520 assert isinstance(body_vec_var, CppCSEVariable), body_vec_var 1521 assert isinstance(other_vec_var, CppCSEVariable), other_vec_var 1522 body_vec_var.dtype = dtype 1523 other_vec_var.dtype = dtype 1524 code.writeline( 1525 f"return {V.kernel.overrides.where(new_mask, body_vec_var, other_vec_var)};" 1526 ) 1527 code.writeline("()") 1528 csevar = V.kernel.cse.generate( 1529 V.kernel.compute, 1530 code, 1531 ) 1532 elif result.is_vec: 1533 csevar = V.kernel.cse.generate( 1534 V.kernel.compute, f"{mask} ? {body_code_vec} : {other_code_vec}" 1535 ) 1536 else: 1537 csevar = V.kernel.cse.generate( 1538 V.kernel.compute, f"{mask} ? {body_code} : {other_code}" 1539 ) 1540 # `result` is explicitly added to the args for correct propagation 1541 # of relevant itervars and vectorization status. 
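        # Rough shape of the code generated by the vector-mask branch above
        # (a simplified sketch, not verbatim output):
        #   auto tmpN = [&] {
        #       if (mask.all_zero()) return other_vec;
        #       else return decltype(body_vec)::blendv(other_vec, body_vec, mask);
        #   }();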
1542 csevar.update_on_args("masked", (mask, body, other, result), {}) 1543 return csevar 1544 1545 @staticmethod 1546 def index_expr(expr, dtype): 1547 assert isinstance(V.kernel, CppVecKernel) 1548 index = V.kernel.rename_indexing(expr) 1549 tiling_var = V.kernel.itervars[V.kernel.tiling_idx] 1550 stride = V.kernel._try_get_const_stride(index, tiling_var) 1551 if stride == 0: 1552 return CppOverrides.index_expr(expr, dtype) 1553 elif stride is not None: 1554 idx = V.kernel.cse.generate( 1555 V.kernel.compute, cexpr(index), bounds=get_bounds_index_expr(expr) 1556 ) 1557 value = ops.to_dtype(idx, dtype) 1558 if isinstance(value, OpsValue): 1559 value = value.value 1560 csevar = V.kernel.arange(value, stride) 1561 else: 1562 csevar = V.kernel._load_or_store_non_contiguous( # type: ignore[assignment] 1563 None, index, dtype, V.kernel.compute 1564 ) 1565 csevar.update_on_args("index_expr", (expr, dtype), {}) 1566 return csevar 1567 1568 @staticmethod 1569 def frexp(x): 1570 cache_keys = f"frexp({x})[0]", f"frexp({x})[1]" 1571 if all(cache_key in V.kernel.cse.cache for cache_key in cache_keys): 1572 return tuple(V.kernel.cse.cache[cache_key] for cache_key in cache_keys) 1573 1574 cdtype = DTYPE_TO_CPP[x.dtype] 1575 size = V.kernel.tail_size if V.kernel.tail_size else V.kernel.tiling_factor 1576 code = BracesBuffer() 1577 exponent = V.kernel.cse.newvar() 1578 mantissa = V.kernel.cse.newvar() 1579 exponent.update_on_args("frexp", (x,), kwargs={}) 1580 mantissa.update_on_args("frexp", (x,), kwargs={}) 1581 n_vec = V.kernel._get_num_vectors(x.dtype) 1582 mantissa_t = ( 1583 f"at::vec::Vectorized<{cdtype}>" 1584 if n_vec == 1 1585 else f"at::vec::VectorizedN<{cdtype}, {n_vec}>" 1586 ) 1587 code.writeline( 1588 f"at::vec::Vectorized<int32_t> {exponent};" 1589 if n_vec == 1 1590 else f"at::vec::VectorizedN<int32_t, {n_vec}> {exponent};" 1591 ) 1592 code.writeline(f"{mantissa_t} {mantissa};") 1593 code.writeline("[&]()") 1594 with code.indent(): 1595 code.writeline( 1596 f"__at_align__ std::array<{cdtype}, {V.kernel.tiling_factor}> tmpbuf;" 1597 ) 1598 code.writeline(f"{x}.store(tmpbuf.data(), {cexpr_index(size)});") 1599 code.writeline( 1600 f"__at_align__ std::array<int32_t, {V.kernel.tiling_factor}> tmpbuf_exponent;" 1601 ) 1602 code.writeline( 1603 f"__at_align__ std::array<{cdtype}, {V.kernel.tiling_factor}> tmpbuf_mantissa;" 1604 ) 1605 code.writeline(f"for (int i = 0; i < {cexpr_index(size)}; i++)") 1606 with code.indent(): 1607 code.writeline( 1608 "tmpbuf_mantissa[i] = std::frexp(tmpbuf[i], &tmpbuf_exponent[i]);" 1609 ) 1610 code.writeline( 1611 f"{exponent} = at::vec::Vectorized<int32_t>::loadu(tmpbuf_exponent.data(), {cexpr_index(size)});" 1612 if n_vec == 1 1613 else f"{exponent} = at::vec::VectorizedN<int32_t, {n_vec}>::loadu(tmpbuf_exponent.data(), {cexpr_index(size)});" 1614 ) 1615 code.writeline( 1616 f"{mantissa} = {mantissa_t}::loadu(tmpbuf_mantissa.data(), {cexpr_index(size)});" 1617 ) 1618 code.writeline("();") 1619 V.kernel.compute.splice(code) 1620 cse_vars = (mantissa, exponent) 1621 for cache_key, cse_var in zip(cache_keys, cse_vars): 1622 V.kernel.cse.cache[cache_key] = cse_var 1623 return mantissa, exponent 1624 1625 @classmethod 1626 def scalarize(cls, scalar_func): 1627 def inner(*args, **kwargs): 1628 assert not kwargs 1629 kernel = V.kernel 1630 assert isinstance(kernel, CppVecKernel) 1631 code = BracesBuffer() 1632 code.writeline("[&]()") 1633 vec_dtype = args[0].dtype 1634 n_vec = kernel._get_num_vectors(vec_dtype) 1635 size = kernel.tail_size if kernel.tail_size else 
kernel.tiling_factor 1636 scalar_args = [] 1637 cdtype = DTYPE_TO_CPP[vec_dtype] 1638 output_mask = scalar_func.__name__ in ( 1639 "isinf", 1640 "isnan", 1641 "signbit", 1642 ) 1643 octype = "bool" if output_mask else cdtype 1644 octype = ( 1645 DTYPE_TO_CPP[args[-2]] 1646 if (scalar_func.__name__ == "to_dtype_bitcast") 1647 else octype 1648 ) 1649 with code.indent(): 1650 for argidx, arg in enumerate(args): 1651 if isinstance(arg, CppCSEVariable): 1652 assert arg.is_vec 1653 assert arg.dtype == vec_dtype 1654 code.writeline( 1655 f"__at_align__ std::array<{cdtype}, {kernel.tiling_factor}> tmpbuf{argidx};" 1656 ) 1657 code.writeline( 1658 f"{arg}.store(tmpbuf{argidx}.data(), {cexpr_index(size)});" 1659 ) 1660 scalar_args.append(f"tmpbuf{argidx}[i]") 1661 else: 1662 scalar_args.append(arg) 1663 code.writeline( 1664 f"__at_align__ std::array<{octype}, {kernel.tiling_factor}> tmpbuf_out;" 1665 ) 1666 res = scalar_func(*scalar_args) 1667 code.writeline(f"for (int i = 0; i < {cexpr_index(size)}; i++)") 1668 with code.indent(): 1669 code.writeline(f"tmpbuf_out[i] = {res};") 1670 if output_mask: 1671 assert not kernel.tail_size 1672 load_args = "tmpbuf_out.data()" 1673 load_fn = f"at::vec::VecMask<{cdtype},{n_vec}>::from" 1674 else: 1675 load_args = f"tmpbuf_out.data(), {cexpr_index(size)}" 1676 if n_vec == 1: 1677 load_fn = f"at::vec::Vectorized<{octype}>::loadu" 1678 else: 1679 load_fn = f" at::vec::VectorizedN<{octype}, {n_vec}>::loadu" 1680 code.writeline(f"return {load_fn}({load_args});") 1681 code.writeline("()") 1682 return code 1683 1684 return inner 1685 1686 @classmethod 1687 def _initialize_scalarize(cls): 1688 for name, method in vars(CppOverrides).items(): 1689 if getattr(method, "__class__", None) == staticmethod and name not in vars( 1690 CppVecOverrides 1691 ): 1692 func = cls.scalarize(method.__func__) 1693 func.__name__ = name 1694 setattr(cls, name, staticmethod(func)) 1695 1696 1697CppVecOverrides._initialize_pointwise_overrides("cppvec") 1698CppVecOverrides._initialize_scalarize() 1699 1700 1701class CppTile2DOverrides(CppVecOverrides): 1702 @staticmethod 1703 def index_expr(expr, dtype): 1704 assert isinstance(V.kernel, CppTile2DKernel) 1705 expr = V.kernel.transform_indexing(expr) 1706 return CppVecOverrides.index_expr(expr, dtype) 1707 1708 1709class CppKernel(Kernel): 1710 overrides = CppOverrides # type: ignore[assignment] 1711 sexpr = cexpr 1712 newvar_prefix = "auto " 1713 suffix = ";" 1714 1715 def __init__(self, args, num_threads): 1716 super().__init__(args) 1717 self.call_ranges: Optional[Tuple[sympy.Expr, ...]] = None 1718 self.ranges: List[sympy.Expr] = [] 1719 self.itervars: List[sympy.Symbol] = [] 1720 self.reduction_depth = None 1721 self.reduction_prefix = IndentedBuffer() 1722 self.reduction_suffix = IndentedBuffer() 1723 self.parallel_reduction_prefix = IndentedBuffer() 1724 self.parallel_reduction_suffix = IndentedBuffer() 1725 self.local_reduction_init = IndentedBuffer() 1726 self.local_reduction_stores = IndentedBuffer() 1727 self.is_reduction = False 1728 self.non_parallel_reduction_prefix = IndentedBuffer() 1729 self.reduction_cse = CSE(self.newvar_prefix, self.suffix, name_prefix="tmp_acc") 1730 self.weight_recps_cse = CSE( 1731 self.newvar_prefix, self.suffix, name_prefix="wrecps" 1732 ) 1733 self.preloads = IndentedBuffer() 1734 self.poststores = IndentedBuffer() 1735 self.num_threads = num_threads # num_threads the kernel specialized for 1736 self.reduction_omp_dec: Dict[Tuple[str, str], str] = {} 1737 1738 def _gen_parallel_reduction_buffers( 1739 
        self,
        acc,
        acc_type,
        reduction_type,
        dtype,
        reduction_combine_fn=reduction_combine,
        reduction_init_fn=reduction_init,
        welford_weight_reciprocal_vec_fn=None,
    ):
        if config.cpp.dynamic_threads and not self.parallel_reduction_prefix:
            self.parallel_reduction_prefix.writeline(
                "int max_threads = omp_get_max_threads();"
            )
        acc_local = f"{acc}_local"
        num_threads = (
            "max_threads" if config.cpp.dynamic_threads else parallel_num_threads()
        )
        acc_per_thread_var_name = f"{acc}_arr"
        acc_per_thread = f"{acc_per_thread_var_name}[{num_threads}]"
        """
        MSVC doesn't support dynamic arrays (VLA), so use std::unique_ptr instead.
        Ref: https://stackoverflow.com/questions/56555406/creating-dynamic-sized-array-using-msvc-c-compiler
        MSVC is the only compiler here without VLA support, and it does not reach good
        inductor performance anyway, so std::unique_ptr is enough to make the code work
        on MSVC. For other compilers, we keep using VLA for the best performance.
        """
        acc_per_thread_unique_ptr_decl = f"auto {acc_per_thread_var_name} = std::make_unique<{acc_type}[]>({num_threads})"
        acc_per_thread_vla_decl = f"{acc_per_thread_var_name}[{num_threads}]"
        acc_local_in_array = acc_per_thread.replace(f"[{num_threads}]", "[tid]")
        self.local_reduction_init.writeline(
            f"{acc_type} {acc_local} = {reduction_init_fn(reduction_type, dtype)};"
        )
        self.parallel_reduction_prefix.writeline(
            f"{acc_per_thread_unique_ptr_decl};"
            if cpp_builder.is_msvc_cl()
            else f"{acc_type} {acc_per_thread_vla_decl};"
        )
        self.parallel_reduction_prefix.writelines(
            [
                f"for (int tid = 0; tid < {num_threads}; tid++)",
                "{",
                f" {acc_local_in_array} = {reduction_init_fn(reduction_type, dtype)};",
                "}",
            ],
        )
        self.local_reduction_stores.writelines(
            [
                f"{acc_local_in_array} = {acc_local};",
            ]
        )
        self.parallel_reduction_suffix.writelines(
            [
                f"for (int tid = 0; tid < {num_threads}; tid++)",
                "{",
                f" {acc} = {reduction_combine_fn(reduction_type, acc, acc_local_in_array, src_dtype=dtype)};",
                "}",
            ],
        )

    def get_reduction_var_pattern(self, line: str):
        return re.search("tmp_acc[0-9]+", line)

    def update_stores_with_parallel_reduction(self):
        for i, line in enumerate(self.stores._lines):
            if isinstance(line, str):
                m = self.get_reduction_var_pattern(line)
                if m:
                    var_name = m.group(0)
                    self.stores._lines[i] = line.replace(var_name, f"{var_name}_local")

    @contextlib.contextmanager
    def masked(self, mask):
        """Context manager to add an additional mask to loads and stores."""
        prior = self._load_mask
        if prior:
            mask = ops.and_(mask, prior)
            if isinstance(mask, OpsValue):
                mask = mask.value
                assert isinstance(mask, CppCSEVariable)
                # see NOTE [dtype of CppCSEVariable]
                # mask's dtype should be bool
                mask.dtype = torch.bool

        self._load_mask = mask
        try:
            yield mask
        finally:
            self._load_mask = prior

    def scale_index_with_offset(
        self, index: sympy.Expr, scale=1, itervar_idx=-1, offset=0
    ):
        var = self.itervars[itervar_idx]
        replacement = {var: var * scale + offset}
        new_index = sympy_subs(index, replacement)
        return new_index

    def index_to_str(self, index: sympy.Expr) -> str:
        """
        Convert an index expr to a string that can be used in
cpp code. 1839 e.g. a sympy expression "s2" may actually appear as "ks1" in the cpp kernel. 1840 """ 1841 return cexpr(self.rename_indexing(index)) 1842 1843 def index_indirect_depends_on(self, index: sympy.Expr, itervar: sympy.Symbol): 1844 """ 1845 Check if an index has free symbol CppCSEVariable that depends on `itervar`. 1846 """ 1847 return any( 1848 self.cse.varname_map[s.name].depends_on(itervar) # type: ignore[attr-defined] 1849 for s in index.free_symbols 1850 if s.name in self.cse.varname_map # type: ignore[attr-defined] 1851 and isinstance(self.cse.varname_map[s.name], CppCSEVariable) # type: ignore[attr-defined] 1852 ) 1853 1854 def index_depends_on(self, index: sympy.Expr, itervar: sympy.Symbol): 1855 return itervar in index.free_symbols or self.index_indirect_depends_on( 1856 index, itervar 1857 ) 1858 1859 def var_ranges(self): 1860 return dict(zip(self.itervars, self.ranges)) 1861 1862 def check_bounds( 1863 self, 1864 expr: sympy.Expr, 1865 size: sympy.Expr, 1866 lower: bool, 1867 upper: bool, 1868 ): 1869 if not (lower or upper): 1870 return 1871 1872 indirect = free_symbol_is_type(expr, SymT.TMP) 1873 if indirect: 1874 # indexing in compute 1875 csevar = ops.index_expr(expr, torch.int64).value 1876 buffer = V.kernel.compute 1877 else: 1878 # indexing in loads 1879 prior_compute = V.kernel.compute 1880 try: 1881 V.kernel.compute = self.loads 1882 csevar = ops.index_expr(expr, torch.int64).value 1883 finally: 1884 V.kernel.compute = prior_compute 1885 buffer = self.loads 1886 1887 size_str = V.kernel.sexpr(self.rename_indexing(size)) if upper else None 1888 1889 line = self.indirect_assert( 1890 csevar, "0" if lower else None, size_str, self._load_mask 1891 ) 1892 self.cse.generate(buffer, line, assignment=False) 1893 1894 def load(self, name: str, index: sympy.Expr): 1895 var = self.args.input(name) 1896 index = self.rename_indexing(index) 1897 line = f"{var}[{cexpr_index(index)}]" 1898 csevar = self.cse.generate(self.loads, line) 1899 csevar.update_on_args("load", (self, name, index), {}) 1900 return csevar 1901 1902 def store(self, name, index, value, mode=None): 1903 assert "buf" in name 1904 var = self.args.output(name) 1905 index = self.rename_indexing(index) 1906 if mode is None: 1907 line = f"{var}[{cexpr_index(index)}] = {value};" 1908 elif mode == "atomic_add": 1909 if not config.cpp.dynamic_threads and self.num_threads == 1: 1910 line = f"{var}[{cexpr_index(index)}] += {value};" 1911 else: 1912 dtype = V.graph.get_dtype(name) 1913 # mirroring static_cast<float>(...) 
in load: 1914 value = f"static_cast<{DTYPE_TO_CPP[dtype]}>({value})" 1915 line = f"atomic_add(&{var}[{cexpr_index(index)}], {value});" 1916 else: 1917 raise NotImplementedError(f"store mode={mode}") 1918 self.stores.writeline(DeferredLine(name, line)) 1919 1920 def reduction(self, dtype, src_dtype, reduction_type, value): 1921 argmax_or_argmin = reduction_type in {"argmax", "argmin"} 1922 reduction_key = src_dtype, reduction_type, value 1923 if reduction_key in self.reduction_cse.reduction_cache: 1924 return self.reduction_cse.reduction_cache[reduction_key] 1925 1926 acc = self.reduction_cse.generate( 1927 self.loads, f"reduction {reduction_key}", write=False 1928 ) 1929 self.is_reduction = True 1930 init_dtype = src_dtype if argmax_or_argmin else dtype 1931 acc_type = reduction_acc_type(reduction_type, init_dtype) 1932 self.reduction_prefix.writeline( 1933 f"{acc_type} {acc} = {reduction_init(reduction_type, init_dtype)};" 1934 ) 1935 assert self.reduction_depth is not None 1936 index = self.itervars[self.reduction_depth] 1937 for i in range(self.reduction_depth + 1, len(self.itervars)): 1938 index = index * self.ranges[i] + self.itervars[i] 1939 self.stores.writeline( 1940 f"{acc} = {reduction_combine(reduction_type, acc, value, index)};" 1941 ) 1942 self._gen_parallel_reduction_buffers(acc, acc_type, reduction_type, init_dtype) 1943 result = reduction_project(reduction_type, acc) 1944 self.reduction_cse.reduction_cache[reduction_key] = result 1945 return result 1946 1947 def store_reduction(self, name, index, value): 1948 index = self.rename_indexing(index) 1949 var = self.args.output(name) 1950 self.reduction_suffix.writeline( 1951 DeferredLine(name, f"{var}[{cexpr_index(index)}] = {value};") 1952 ) 1953 1954 def set_ranges(self, lengths, reduction_lengths): 1955 if self.call_ranges: 1956 assert self.call_ranges == tuple(lengths) + tuple( 1957 reduction_lengths 1958 ), f"{self.call_ranges} == {tuple(lengths)} + {tuple(reduction_lengths)}" 1959 assert self.reduction_depth == len(lengths) 1960 else: 1961 self.call_ranges = tuple(lengths) + tuple(reduction_lengths) 1962 self.ranges = [self.rename_indexing(x) for x in self.call_ranges] 1963 self.itervars = [ 1964 sympy_index_symbol_with_prefix(SymT.XBLOCK, n) 1965 for n in range(len(self.ranges)) 1966 ] 1967 self.reduction_depth = len(lengths) 1968 return ( 1969 self.itervars[: self.reduction_depth], 1970 self.itervars[self.reduction_depth :], 1971 ) 1972 1973 def size_hint(self): 1974 return V.graph.sizevars.size_hint( 1975 sympy_product(self.call_ranges), fallback=8192 1976 ) 1977 1978 def codegen_loops_impl(self, loop_nest, code, worksharing): 1979 threads = parallel_num_threads() 1980 assert self.call_ranges is not None 1981 kernels = loop_nest.get_kernels() 1982 has_outer_loop_kernel = any( 1983 isinstance(kernel, OuterLoopFusedKernel) for kernel in kernels 1984 ) 1985 if has_outer_loop_kernel: 1986 assert len(kernels) == 1 1987 assert isinstance(kernels[0], OuterLoopFusedKernel) 1988 par_depth = kernels[0].decide_parallel_depth( 1989 loop_nest.max_parallel_depth(), threads 1990 ) 1991 else: 1992 par_depth = self.decide_parallel_depth( 1993 loop_nest.max_parallel_depth(), threads 1994 ) 1995 1996 with contextlib.ExitStack() as stack: 1997 if par_depth: 1998 if loop_nest.is_reduction_only(): 1999 # need to close the worksharing scope to define reduction vars outside it 2000 worksharing.close() 2001 else: 2002 worksharing.parallel(threads) 2003 loop_nest.mark_parallel(par_depth) 2004 elif threads > 1: 2005 if worksharing.single(): 2006 
stack.enter_context(code.indent()) 2007 2008 def gen_loop_kernel(loop: LoopLevel): 2009 def is_parallel_reduction(loop): 2010 root = loop.get_root() 2011 return root.is_reduction and root.parallel 2012 2013 kernels = loop.get_kernels() 2014 assert len(kernels) == 1 2015 if not isinstance( 2016 kernels[0], OuterLoopFusedKernel 2017 ) and is_parallel_reduction(loop): 2018 kernels[0].update_stores_with_parallel_reduction() 2019 gen_kernel(kernels[0]) 2020 2021 def gen_kernel(kernel): 2022 if isinstance(kernel, OuterLoopFusedKernel): 2023 for loop in kernel.inner: 2024 if loop.inner: 2025 gen_loops(loop.inner, loop.is_reduction) 2026 else: 2027 with contextlib.ExitStack() as stack: 2028 # If there is any kernel existing at the final outer loop fusion level, 2029 # the kernel code should be placed within its respective indent to prevent 2030 # the duplication of variable definitions. 2031 stack.enter_context(code.indent()) 2032 gen_loop_kernel(loop) 2033 else: 2034 with contextlib.ExitStack() as stack: 2035 assert kernel 2036 if hasattr(kernel, "codegen_inner_loops"): 2037 code.splice(kernel.preloads) 2038 kernel.codegen_inner_loops(code) 2039 stack.enter_context(code.indent()) 2040 code.splice(kernel.loads) 2041 code.splice(kernel.compute) 2042 code.splice(kernel.stores) 2043 if hasattr(kernel, "codegen_inner_loops"): 2044 code.splice(kernel.poststores) 2045 2046 def get_reduction_code_buffer(loops, buffer="prefix"): 2047 assert buffer in ("prefix", "suffix", "local") 2048 for loop in loops: 2049 for kernel in loop.get_kernels(): 2050 if buffer == "local": 2051 return ( 2052 kernel.local_reduction_init, 2053 kernel.local_reduction_stores, 2054 ) 2055 elif buffer == "suffix": 2056 suffix = kernel.reduction_suffix 2057 if loop.parallel: 2058 suffix = kernel.parallel_reduction_suffix + suffix 2059 return suffix 2060 else: 2061 prefix = kernel.reduction_prefix 2062 if loop.parallel: 2063 prefix = prefix + kernel.parallel_reduction_prefix 2064 else: 2065 prefix = prefix + kernel.non_parallel_reduction_prefix 2066 return prefix 2067 2068 def gen_loops(loops: List[LoopLevel], in_reduction=False): 2069 with contextlib.ExitStack() as stack_outer: 2070 local_reduction_init = local_reduction_stores = None 2071 if loops: 2072 loop = loops[0] 2073 if loop.is_reduction and not in_reduction: 2074 reduction_prefix = get_reduction_code_buffer(loops) 2075 if reduction_prefix: 2076 stack_outer.enter_context(code.indent()) 2077 code.splice(reduction_prefix) 2078 if loop_nest.is_reduction_only() and loop.parallel: 2079 ( 2080 local_reduction_init, 2081 local_reduction_stores, 2082 ) = get_reduction_code_buffer(loops, "local") 2083 worksharing.parallel(threads) 2084 if local_reduction_init: 2085 assert local_reduction_stores 2086 code.splice(local_reduction_init) 2087 2088 for loop in loops: 2089 gen_loop(loop) 2090 2091 if loops: 2092 loop = loops[0] 2093 if loop_nest.is_reduction_only() and loop.parallel: 2094 if local_reduction_stores: 2095 code.splice(local_reduction_stores) 2096 worksharing.close() 2097 if loop.is_reduction and not in_reduction: 2098 code.splice(get_reduction_code_buffer(loops, "suffix")) 2099 2100 def gen_loop(loop: LoopLevel): 2101 with contextlib.ExitStack() as stack: 2102 loop_lines = loop.lines() 2103 if loop_lines is None: 2104 return 2105 code.writelines(loop_lines) 2106 stack.enter_context(code.indent()) 2107 # generate inner loops or loop body 2108 if loop.inner: 2109 gen_loops(loop.inner, loop.is_reduction) 2110 else: 2111 gen_loop_kernel(loop) 2112 2113 
stack.enter_context(code.indent()) 2114 if loop_nest.root: 2115 if ( 2116 has_outer_loop_kernel 2117 and isinstance(V.local_buffer_context, LocalBufferContext) 2118 and V.local_buffer_context.local_buffers 2119 ): 2120 # Allocate local buffer 2121 local_buffers = V.local_buffer_context.local_buffers 2122 for local_buffer in local_buffers.values(): 2123 # For dynamic size, rename s to ks 2124 local_buf_size = sympy_product( 2125 [ 2126 self.rename_indexing(size_val) 2127 for size_val in local_buffer.get_layout().size 2128 ] 2129 ) 2130 local_buf_dtype = DTYPE_TO_CPP[local_buffer.get_layout().dtype] 2131 allocate = f"std::make_unique<{local_buf_dtype} []>({cexpr(local_buf_size)})" 2132 local_buffer_name = local_buffer.get_name() 2133 code.splice( 2134 f"std::unique_ptr<{local_buf_dtype} []> buf_{local_buffer_name} = {allocate};" 2135 ) 2136 code.splice( 2137 f"{local_buf_dtype}* {local_buffer_name} = buf_{local_buffer_name}.get();" 2138 ) 2139 gen_loops(loop_nest.root) 2140 else: 2141 gen_kernel(loop_nest.kernel) 2142 2143 def codegen_loops(self, code, worksharing): 2144 loop_nest = LoopNestWithSplit.build(self) 2145 self.codegen_loops_impl(loop_nest, code, worksharing) 2146 2147 @property 2148 def assert_function(self) -> str: 2149 if V.graph.aot_mode: 2150 # TODO: Using AOTI_TORCH_CHECK is causing performance drop for some models 2151 # compared with JIT Inductor which uses TORCH_CHECK 2152 return "AOTI_TORCH_CHECK" 2153 else: 2154 return "TORCH_CHECK" 2155 2156 def decide_parallel_depth(self, max_parallel_depth, threads): 2157 assert self.call_ranges is not None 2158 ranges = self.call_ranges[:max_parallel_depth] 2159 seq = self.size_hint() 2160 par = 1 2161 depth = 0 2162 for expr in ranges: 2163 hint = V.graph.sizevars.size_hint(expr, fallback=8192) 2164 if par >= 2 * threads or par == threads: 2165 break 2166 if seq // threads < config.cpp.min_chunk_size: 2167 # not enough work 2168 break 2169 depth += 1 2170 par *= hint 2171 seq /= hint 2172 # if we assume thread number is dynamic, make sure we 2173 # have at least one parallel scope and let OMP runtime 2174 # to manage the serial vs. parallel. 
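        # (For instance, with dynamic threads and a single small range whose per-thread
        # chunk (seq // threads) falls below config.cpp.min_chunk_size, the loop above
        # leaves depth at 0; the branch below still emits one parallel scope so the OMP
        # runtime can decide serial vs. parallel at run time.)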
2175 if config.cpp.dynamic_threads and depth == 0 and len(ranges) > 0: 2176 depth = 1 2177 return depth 2178 2179 @contextlib.contextmanager 2180 def write_to_suffix(self): 2181 prior = (self.loads, self.compute, self.stores, self.cse) 2182 self.loads = IndentedBuffer() 2183 self.compute = IndentedBuffer() 2184 self.stores = IndentedBuffer() 2185 self.cse = self.cse.clone() 2186 yield 2187 self.reduction_suffix.splice(self.loads) 2188 self.reduction_suffix.splice(self.compute) 2189 self.reduction_suffix.splice(self.stores) 2190 (self.loads, self.compute, self.stores, self.cse) = prior 2191 2192 def create_cse_var(self, *args, **kwargs): 2193 return CppCSEVariable(*args, **kwargs) 2194 2195 def get_to_dtype_expr(self, src, dtype, src_dtype): 2196 return f"c10::convert<{DTYPE_TO_CPP[dtype]}>({src})" 2197 2198 def cache_dtype_convert(self, dst, dst_dtype, src, src_dtype): 2199 expr = self.get_to_dtype_expr(src, dst_dtype, src_dtype) 2200 self.cse.cache[expr] = dst 2201 2202 2203class CppVecKernel(CppKernel): 2204 overrides = CppVecOverrides # type: ignore[assignment] 2205 2206 def __init__( 2207 self, 2208 args, 2209 num_threads, 2210 tiling_factor, 2211 tiling_idx, 2212 tail_size=None, 2213 ): 2214 super().__init__(args, num_threads) 2215 self.vec_isa = cpu_vec_isa.pick_vec_isa() 2216 assert self.vec_isa 2217 assert tiling_factor > 0, "Expect pass in Non-Zero tiling_factor explicitly" 2218 self.tiling_factor = tiling_factor 2219 self.tiling_idx = tiling_idx 2220 self.tail_size = tail_size 2221 self.num_elems = tail_size if tail_size else tiling_factor 2222 2223 def _try_get_const_stride(self, index: sympy.Expr, itervar: sympy.Symbol): 2224 if self.index_indirect_depends_on(index, itervar): 2225 return None 2226 for indirect_var in ( 2227 self.cse.varname_map[s.name] # type: ignore[attr-defined] 2228 for s in index.free_symbols 2229 if symbol_is_type(s, SymT.TMP) 2230 ): 2231 assert isinstance(indirect_var, CppCSEVariable) 2232 if indirect_var.is_vec: 2233 return None 2234 stride = stride_at_vec_range(index, itervar, self.tiling_factor) 2235 return stride if stride.is_number else None 2236 2237 def _get_num_vectors(self, dtype: torch.dtype) -> int: 2238 num_vectors = math.ceil( 2239 self.tiling_factor * dtype.itemsize * 8 / self.vec_isa.bit_width() 2240 ) 2241 assert num_vectors >= 1 2242 return num_vectors 2243 2244 def _get_raw_num_vectors(self, dtype: torch.dtype) -> float: 2245 # This utility function is used to check if the vector lanes has been 2246 # fully utilized. For example, uint8 will only use 1/4 of the vector lanes. 
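        # Illustrative arithmetic (assuming a 256-bit ISA such as AVX2 and
        # tiling_factor = 8): float32 gives 8 * 4 * 8 / 256 = 1.0 (lanes fully used),
        # while uint8 gives 8 * 1 * 8 / 256 = 0.25.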
2247 return self.tiling_factor * dtype.itemsize * 8 / self.vec_isa.bit_width() 2248 2249 def _get_vec_type(self, dtype: torch.dtype) -> str: 2250 num_vectors = self._get_num_vectors(dtype) 2251 if num_vectors == 1: 2252 return f"at::vec::Vectorized<{DTYPE_TO_CPP[dtype]}>" 2253 else: 2254 return f"at::vec::VectorizedN<{DTYPE_TO_CPP[dtype]},{num_vectors}>" 2255 2256 def _get_mask_type(self, dtype: torch.dtype = torch.float) -> str: 2257 if dtype == torch.bool: 2258 return "" 2259 num_vectors = self._get_num_vectors(dtype) 2260 return f"at::vec::VecMask<{DTYPE_TO_CPP[dtype]},{num_vectors}>" 2261 2262 def _get_mask_cast(self, mask: CppCSEVariable, dtype: torch.dtype) -> str: 2263 assert mask.dtype == torch.bool, repr(mask) 2264 num_vectors = self._get_num_vectors(dtype) 2265 return f"{mask}.template cast<{DTYPE_TO_CPP[dtype]},{num_vectors}>()" 2266 2267 def get_reduction_var_pattern(self, line: str): 2268 return re.search("tmp_acc[0-9]+_vec", line) 2269 2270 def _get_vec_load_line( 2271 self, 2272 var: str, 2273 index: sympy.Expr, 2274 dtype: torch.dtype, 2275 load_mask: Optional[CppCSEVariable] = None, 2276 ): 2277 """ 2278 Get a load line str that loads a vector from `var` at `index` of type `dtype`. 2279 If `load_mask` is not None, we do a masked load accordingly. 2280 Notes on the `dtype`: 2281 1. We always load `self.tiling_factor` number of elements regardless of the `dtype`. 2282 It means we load half of the vector lanes for 16-bit data types and quarter of the 2283 vector lanes for 8-bit data types. 2284 2. `torch.bool` and `torch.uint8` could mean masks and we load them as float mask vectors. 2285 """ 2286 cpp_type = DTYPE_TO_CPP[dtype] 2287 num_vectors = self._get_num_vectors(dtype) 2288 load_mask_str = None 2289 if load_mask: 2290 if not load_mask.is_vec: 2291 # TODO: avoid hard-code torch.float 2292 load_mask_str = f"{self._get_mask_type(torch.float)}::from({load_mask})" 2293 else: 2294 load_mask_str = f"{self._get_mask_cast(load_mask, torch.float)}" 2295 loadbuf = f"{var} + {cexpr_index(index)}" if index != 0 else var 2296 if dtype == torch.bool: 2297 # TODO: should we consider load mask here? 2298 line = f"{self._get_mask_type()}::from({loadbuf})" 2299 else: 2300 line = ( 2301 f"{load_mask_str}.template loadu<{cpp_type},{num_vectors}>({loadbuf})" 2302 if load_mask_str 2303 else f"{self._get_vec_type(dtype)}::loadu({loadbuf}, {cexpr_index(self.num_elems)})" 2304 ) 2305 return line 2306 2307 def _load_or_store_non_contiguous( 2308 self, 2309 var: Optional[str], 2310 index: sympy.Expr, 2311 dtype: torch.dtype, 2312 buffer: Optional[IndentedBuffer] = None, 2313 store_value: Optional[Union[str, CppCSEVariable]] = None, 2314 accu_store: bool = False, 2315 ) -> Optional[CppCSEVariable]: 2316 """ 2317 Load or store a vector in a non-contiguous way. The vector is initialized from an array that is 2318 filled in an inner loop over the tiling factor. 2319 :param var: buffer to load from or store to, i.e. `var[transformed(index)]`. If None, we load the index 2320 as index expression, i.e. `transformed(index)`. 2321 :param index: index into the `var` or the index expression by its own if `var` is None. 2322 The `index` could contain indirect indexing or the tiling itervar. When used in 2323 the inner loop, the index is transformed as follows: 2324 1. the index is linearized along the tiling dim. 2325 2. the indirect indexing vector variables are transformed into arrays over the tiling dim. 2326 :param dtype: data type of `var` or `index` if `var` is None. 
2327 :param buffer: the code buffer to write the generated code to. If None, we write to `self.loads`. 2328 :param store_value: the value to store. If None, we load the vector. 2329 :param accu_store: whether accumulate the store_value to store_ptr. If True, a store_value should be provided 2330 :return: a CppCSEVariable that represents the loaded vector or None if it is a store. 2331 """ 2332 assert not store_value or var is not None, "store var must be provided" 2333 if accu_store: 2334 assert store_value 2335 if buffer is None: 2336 buffer = self.loads 2337 2338 def get_result_size(dtype: torch.dtype) -> int: 2339 if dtype.itemsize < 4: 2340 return self.num_elems * (4 // dtype.itemsize) 2341 else: 2342 return self.num_elems 2343 2344 def get_tiling_size(dtype: torch.dtype) -> int: 2345 if dtype.itemsize < 4: 2346 return self.tiling_factor * (4 // dtype.itemsize) 2347 else: 2348 return self.tiling_factor 2349 2350 def vec_to_array(vec_var: CppCSEVariable) -> CppCSEVariable: 2351 assert vec_var.is_vec 2352 code = BracesBuffer() 2353 code.writeline("[&]") 2354 with code.indent(): 2355 vec_dtype = vec_var.dtype 2356 assert vec_dtype is not None 2357 if vec_dtype == torch.bool: 2358 vec_dtype = torch.float 2359 result_size = get_result_size(vec_dtype) 2360 tiling_size = get_tiling_size(vec_dtype) 2361 code.writeline( 2362 f"__at_align__ std::array<{DTYPE_TO_CPP[vec_dtype]}, {tiling_size}> tmpbuf;" 2363 ) 2364 line = f"{vec_var}.store(tmpbuf.data(), {cexpr_index(result_size)});" 2365 code.writeline(line) 2366 code.writeline("return tmpbuf;") 2367 code.writeline("()") 2368 csevar = self.cse.generate(buffer, code) 2369 assert isinstance(csevar, CppCSEVariable) 2370 return csevar 2371 2372 code = BracesBuffer() 2373 code.writeline("[&]") 2374 with code.indent(): 2375 result_size = get_result_size(dtype) 2376 tiling_size = get_tiling_size(dtype) 2377 result_declare = ( 2378 f"__at_align__ std::array<{DTYPE_TO_CPP[dtype]}, {tiling_size}> tmpbuf;" 2379 ) 2380 code.writeline(result_declare) 2381 if store_value: 2382 code.writeline( 2383 f"{store_value}.store(tmpbuf.data(), {cexpr_index(result_size)});" 2384 ) 2385 itervar_inner = sympy_index_symbol( 2386 f"{self.itervars[self.tiling_idx]}_inner" 2387 ) 2388 replacements = {} 2389 for indirect_var in ( 2390 self.cse.varname_map[s.name] # type: ignore[attr-defined] 2391 for s in index.free_symbols 2392 if symbol_is_type(s, SymT.TMP) 2393 ): 2394 assert isinstance(indirect_var, CppCSEVariable) 2395 if indirect_var.is_vec: 2396 array_var = vec_to_array(indirect_var) 2397 replacements[indirect_var] = f"{array_var}[{itervar_inner}]" 2398 index = self.scale_index_with_offset( 2399 index, itervar_idx=self.tiling_idx, offset=itervar_inner 2400 ) 2401 load_mask = None 2402 if self._load_mask is not None: 2403 assert not store_value, "unexpected store with load mask" 2404 assert isinstance(self._load_mask, CppCSEVariable), self._load_mask 2405 if self._load_mask.is_vec: 2406 load_mask = f"{self._load_mask}.is_masked({itervar_inner})" 2407 else: 2408 load_mask = f"{self._load_mask} != 0" 2409 if cpp_builder.is_gcc(): 2410 code.writeline(f"#pragma GCC unroll {self.tiling_factor}") 2411 else: 2412 code.writeline(f"#pragma unroll {self.tiling_factor}") 2413 code.writeline( 2414 f"for (long {itervar_inner} = 0; " 2415 + f"{itervar_inner} < {cexpr_index(self.num_elems)}; " 2416 + f"{itervar_inner}++)" 2417 ) 2418 with code.indent(), contextlib.ExitStack() as stack: 2419 index_c = cexpr_index(index) 2420 for indirect_var in replacements: 2421 index_c = re.sub( 2422 
r"\b" + f"{indirect_var}" + r"\b", 2423 replacements[indirect_var], 2424 index_c, 2425 ) 2426 rhs = f"{var}[{index_c}]" if var is not None else f"{index_c}" 2427 if load_mask: 2428 code.writeline(f"if ({load_mask})") 2429 stack.enter_context(code.indent()) 2430 if store_value: 2431 conjunction = "+=" if accu_store else "=" 2432 code.writeline(f"{rhs} {conjunction} tmpbuf[{itervar_inner}];") 2433 else: 2434 code.writeline(f"tmpbuf[{itervar_inner}] = {rhs};") 2435 if not store_value: 2436 load_line = self._get_vec_load_line("tmpbuf.data()", 0, dtype) # type: ignore[arg-type] 2437 code.writeline(f"return {load_line};") 2438 code.writeline("()") 2439 if store_value: 2440 code.writeline(";") 2441 buffer.splice(code) 2442 return None 2443 else: 2444 csevar = self.cse.generate(buffer, code) 2445 assert isinstance(csevar, CppCSEVariable) 2446 csevar.is_vec = True 2447 return csevar 2448 2449 def load(self, name: str, index: sympy.Expr): 2450 var = self.args.input(name) 2451 index = self.rename_indexing(index) 2452 dtype = V.graph.get_dtype(name) 2453 tiling_var = self.itervars[self.tiling_idx] 2454 stride = self._try_get_const_stride(index, tiling_var) 2455 if stride == 0: 2456 # load scalar and lazily broadcast it on demand 2457 return super().load(name, index) 2458 elif stride == 1: 2459 # load contiguously 2460 line = self._get_vec_load_line(var, index, dtype, self._load_mask) 2461 csevar = self.cse.generate(self.loads, line) # type: ignore[assignment] 2462 else: 2463 csevar = self._load_or_store_non_contiguous(var, index, dtype) # type: ignore[assignment] 2464 assert isinstance(csevar, CppCSEVariable) 2465 csevar.update_on_args("load", (self, name, index), {}) 2466 csevar.is_vec = True 2467 return csevar 2468 2469 def _get_store_line( 2470 self, 2471 value: Union[str, CppCSEVariable], 2472 var: str, 2473 index: sympy.Expr, 2474 dtype: torch.dtype, 2475 accu_store: bool = False, 2476 ): 2477 """ 2478 Get a store line buffer that stores `value` into `var` at `index` of `dtype`. It handles 2479 both contiguous and non-contiguous store cases. 2480 :param value: Vectorized type templaterized on `dtype`. 2481 :param var: buffer to store into. 2482 :index: index into the `var`. 
2483 """ 2484 # when value's type is str (e.g., welford reduction), caller should make sure 2485 # it is a vector 2486 assert isinstance(value, str) or ( 2487 isinstance(value, CppCSEVariable) and value.is_vec 2488 ), value 2489 tiling_var = self.itervars[self.tiling_idx] 2490 var_expr = f"{var} + {cexpr_index(index)}" 2491 stride = self._try_get_const_stride(index, tiling_var) 2492 code = IndentedBuffer() 2493 if stride == 1: 2494 if dtype == torch.float and self.tail_size is None: 2495 code.writeline(f"{value}.store({var_expr});") 2496 else: 2497 code.writeline( 2498 f"{value}.store({var_expr}, {cexpr_index(self.num_elems)});" 2499 ) 2500 else: 2501 self._load_or_store_non_contiguous( 2502 var, index, dtype, buffer=code, store_value=value, accu_store=accu_store 2503 ) 2504 return code 2505 2506 def store(self, name, index, value, mode=None): 2507 assert "buf" in name 2508 assert isinstance(value, CppCSEVariable), value 2509 if not value.is_vec: 2510 # this happens when we store a scalar into a vectorized buffer like "fill" 2511 value = self.broadcast(value) 2512 var = self.args.output(name) 2513 index = self.rename_indexing(index) 2514 dtype = V.graph.get_dtype(name) 2515 if mode is None: 2516 code = self._get_store_line(value, var, index, dtype) 2517 self.stores.splice(code.map(lambda x: DeferredLine(name, x))) 2518 elif mode == "atomic_add": 2519 if not config.cpp.dynamic_threads and self.num_threads == 1: 2520 code = self._get_store_line( 2521 f"{value}", 2522 var, 2523 index, 2524 dtype, 2525 accu_store=True, 2526 ) 2527 self.stores.splice(code.map(lambda x: DeferredLine(name, x))) 2528 else: 2529 n_src = self._get_num_vectors(dtype) 2530 n_idx = self._get_num_vectors(torch.int64) 2531 cdtype = DTYPE_TO_CPP[dtype] 2532 index = ops.index_expr(index, torch.int64).value 2533 assert index.is_vec 2534 line = f"atomic_add_vec<{cdtype}, {n_idx}, {n_src}>({var}, {index}, {value});" 2535 self.stores.writeline(DeferredLine(name, line)) 2536 else: 2537 raise NotImplementedError(f"store mode={mode}") 2538 2539 def reduction(self, dtype, src_dtype, reduction_type, value): 2540 assert reduction_type in VECTORIZABLE_RTYPES 2541 argmax_or_argmin = reduction_type in {"argmax", "argmin"} 2542 horizontal_reduction = self.tiling_idx >= self.reduction_depth 2543 init_dtype = src_dtype if argmax_or_argmin else dtype 2544 assert isinstance(value, CppCSEVariable), value 2545 2546 if not value.is_vec: 2547 value = self.broadcast(value) 2548 2549 reduction_key = src_dtype, reduction_type, value 2550 if reduction_key in self.reduction_cse.reduction_cache: 2551 return self.reduction_cse.reduction_cache[reduction_key] 2552 2553 vec_ns = "at::vec" 2554 vec = f"{vec_ns}::Vectorized<{DTYPE_TO_CPP[dtype]}>" 2555 acc_type = reduction_acc_type(reduction_type, init_dtype) 2556 acc_type_vec = self.reduction_acc_type_vec(reduction_type, init_dtype) 2557 2558 acc = self.reduction_cse.generate( 2559 self.loads, f"reduction {reduction_key}", write=False 2560 ) 2561 acc_vec = f"{acc}_vec" 2562 self.is_reduction = True 2563 self.reduction_prefix.writeline( 2564 f"{acc_type} {acc} = {reduction_init(reduction_type, init_dtype)};" 2565 ) 2566 self.reduction_prefix.writeline( 2567 f"{acc_type_vec} {acc_vec} = {self.reduction_init_vec(reduction_type, init_dtype)};" 2568 ) 2569 if reduction_type == "welford_reduce": 2570 # save the reciprocal of weights for welford reduce 2571 assert self.reduction_depth is not None 2572 # use masked acc_vec for tail vec kernel 2573 self.reduction_prefix.writeline( 2574 f"{acc_type_vec} 
masked_{acc_vec} = {self.reduction_init_vec(reduction_type, dtype)};" 2575 ) 2576 reduction_size = functools.reduce( 2577 lambda x, y: x * y, self.ranges[self.reduction_depth :] 2578 ) 2579 reduction_factor = ( 2580 self.tiling_factor if self.tiling_idx >= self.reduction_depth else 1 2581 ) 2582 self.weight_recp_vec_range = FloorDiv(reduction_size, reduction_factor) 2583 if self.weight_recp_vec_range not in self.weight_recps_cse.reduction_cache: 2584 self.weight_recps_val = self.weight_recps_cse.generate( 2585 self.compute, f"reduction {self.weight_recp_vec_range}", write=False 2586 ) 2587 self.weight_recps_cse.reduction_cache[ 2588 self.weight_recp_vec_range 2589 ] = self.weight_recps_val 2590 self.non_parallel_reduction_prefix.writeline( 2591 self.welford_weight_reciprocal_vec(dtype) 2592 ) 2593 # generate weight_recps for parallel reduction 2594 num_threads = ( 2595 "max_threads" 2596 if config.cpp.dynamic_threads 2597 else parallel_num_threads() 2598 ) 2599 self.local_reduction_init.writeline( 2600 self.welford_weight_reciprocal_vec(dtype, num_threads) 2601 ) 2602 else: 2603 self.weight_recps_val = self.weight_recps_cse.reduction_cache[ 2604 self.weight_recp_vec_range 2605 ] 2606 # use masked acc_vec for tail vec kernel 2607 acc_vec_ = f"masked_{acc_vec}" if self.tail_size else acc_vec 2608 self.stores.writeline( 2609 f"{acc_vec_} = {self.reduction_combine_vec(reduction_type, acc_vec_, value, True)};" 2610 ) 2611 else: 2612 assert self.reduction_depth is not None 2613 index = self.itervars[self.reduction_depth] 2614 for i in range(self.reduction_depth + 1, len(self.itervars)): 2615 index = index * self.ranges[i] + self.itervars[i] 2616 combine = self.reduction_combine_vec( 2617 reduction_type, 2618 acc_vec, 2619 value, 2620 index=index, 2621 horizontal_reduction=horizontal_reduction, 2622 src_dtype=src_dtype, 2623 ) 2624 self.stores.writeline(f"{acc_vec} = {combine};") 2625 self._gen_parallel_reduction_buffers( 2626 acc, 2627 acc_type, 2628 reduction_type, 2629 init_dtype, 2630 ) 2631 self._gen_parallel_reduction_buffers( 2632 acc_vec, 2633 acc_type_vec, 2634 reduction_type, 2635 init_dtype, 2636 reduction_combine_fn=self.reduction_combine_vec, 2637 reduction_init_fn=self.reduction_init_vec, 2638 ) 2639 if reduction_type == "welford_reduce": 2640 # use masked acc_vec for tail vec kernel 2641 self._gen_parallel_reduction_buffers( 2642 f"masked_{acc_vec}", 2643 acc_type_vec, 2644 reduction_type, 2645 dtype, 2646 reduction_combine_fn=self.reduction_combine_vec, 2647 reduction_init_fn=self.reduction_init_vec, 2648 ) 2649 tmpvar: Union[str, CSEVariable] 2650 is_bool = dtype == torch.bool 2651 if horizontal_reduction: 2652 # Horizontal reduction 2653 if is_welford_reduction(reduction_type): 2654 assert self._get_num_vectors(dtype) in [ 2655 1, 2656 2, 2657 ], "Welford reduction does not support VectorizedN (N>2)" 2658 next_value = f"welford_vec_reduce_all({acc_vec})" 2659 masked_next_value = f"welford_vec_reduce_all(masked_{acc_vec})" 2660 self.reduction_suffix.writeline( 2661 f"{acc} = {reduction_combine(reduction_type, acc, masked_next_value)};" 2662 ) 2663 elif argmax_or_argmin: 2664 next_value = f"{reduction_type}_vec_reduce_all({acc_vec})" 2665 elif is_bool: 2666 if reduction_type in ( 2667 "any", 2668 "sum", 2669 "max", 2670 ): 2671 next_value = f"!{acc_vec}.all_zero()" 2672 else: 2673 assert reduction_type == "min" 2674 next_value = f"{acc_vec}.all_masked()" 2675 else: 2676 reduce_all_body = ( 2677 "{ return " 2678 + self.reduction_combine_vec(reduction_type, "x", "y") 2679 + "; }" 
2680 ) 2681 is_bool = dtype == torch.bool 2682 # we are using at::vec::VecMask<float, N> for bool 2683 vec_dtype = torch.float if is_bool else dtype 2684 vec = f"at::vec::Vectorized<{DTYPE_TO_CPP[vec_dtype]}>" 2685 vec_reduce_all_func = f"at::vec::vec_reduce_all<{DTYPE_TO_CPP[vec_dtype]}, {self._get_num_vectors(vec_dtype)}>" 2686 next_value = f"{vec_reduce_all_func}([]({vec}& x, {vec}& y) {reduce_all_body}, {acc_vec})" 2687 2688 self.reduction_suffix.writeline( 2689 f"{acc} = {reduction_combine(reduction_type, acc, next_value, src_dtype=src_dtype)};" 2690 ) 2691 tmpvar = acc 2692 else: 2693 tmpvar = acc_vec 2694 if is_welford_reduction(reduction_type): 2695 masked_tmpvar = f"masked_{tmpvar}" 2696 self.reduction_suffix.writeline( 2697 f"{tmpvar} = {reduction_combine(reduction_type, tmpvar, masked_tmpvar)};" 2698 ) 2699 2700 result = reduction_project(reduction_type, tmpvar) 2701 self.reduction_cse.reduction_cache[reduction_key] = result 2702 return result 2703 2704 def store_reduction(self, name, index, value): 2705 index = self.rename_indexing(index) 2706 var = self.args.output(name) 2707 out_dtype = V.graph.get_dtype(name) 2708 dtype = ( 2709 (out_dtype if out_dtype == torch.double else torch.float) 2710 if out_dtype.is_floating_point 2711 else torch.int64 2712 ) 2713 out_num_vectors = V.kernel._get_num_vectors(out_dtype) 2714 src_num_vectors = V.kernel._get_num_vectors(dtype) 2715 code = IndentedBuffer() 2716 if self.tiling_idx >= self.reduction_depth: 2717 # Horizontal reduction 2718 code.writeline( 2719 f"{var}[{cexpr_index(index)}] = static_cast<{DTYPE_TO_CPP[out_dtype]}>({value});" 2720 ) 2721 else: 2722 # Vertical reduction 2723 if out_dtype != dtype: 2724 converted_value = f"{DTYPE_TO_CPP[out_dtype]}_{value}" 2725 if out_dtype == torch.bool: 2726 convert = f"{value}.template cast<bool,{self._get_num_vectors(torch.bool)}>()" 2727 else: 2728 if src_num_vectors == out_num_vectors == 1: 2729 convert = ( 2730 f"at::vec::convert<{DTYPE_TO_CPP[out_dtype]}>({value})" 2731 ) 2732 else: 2733 convert = ( 2734 f"at::vec::convert<{DTYPE_TO_CPP[out_dtype]}," 2735 f"{out_num_vectors},{DTYPE_TO_CPP[dtype]},{src_num_vectors}>({value})" 2736 ) 2737 code.writeline(f"auto {converted_value} = {convert};") 2738 value = converted_value 2739 code.splice(self._get_store_line(value, var, index, out_dtype)) 2740 self.reduction_suffix.splice(code.map(lambda x: DeferredLine(name, x))) 2741 2742 def broadcast(self, scalar_var: CppCSEVariable) -> CppCSEVariable: 2743 assert not scalar_var.is_vec 2744 if scalar_var.dtype == torch.bool: 2745 vec_var = self.cse.generate( 2746 self.compute, f"{self._get_mask_type()}::from({scalar_var.name})" 2747 ) 2748 else: 2749 assert scalar_var.dtype is not None 2750 vec_var = self.cse.generate( 2751 self.compute, 2752 f"{self._get_vec_type(scalar_var.dtype)}({scalar_var.name})", 2753 ) 2754 assert isinstance(vec_var, CppCSEVariable) 2755 vec_var.dtype = scalar_var.dtype 2756 vec_var.dependent_itervars = scalar_var.dependent_itervars 2757 vec_var.is_vec = True 2758 return vec_var 2759 2760 def arange(self, index: CppCSEVariable, stride: sympy.Symbol) -> CppCSEVariable: 2761 assert not index.is_vec 2762 assert index.dtype is not None 2763 csevar = self.cse.generate( 2764 self.compute, 2765 f"{self._get_vec_type(index.dtype)}::arange({index}, {stride})", 2766 ) 2767 assert isinstance(csevar, CppCSEVariable) 2768 csevar.dtype = index.dtype 2769 csevar.is_vec = True 2770 return csevar 2771 2772 def reduction_init_vec(self, reduction_type, dtype): 2773 scalar_type = 
DTYPE_TO_COMPUTATION_DTYPE[dtype] 2774 vec_type = self._get_vec_type(scalar_type) 2775 2776 if is_welford_reduction(reduction_type): 2777 return f"Welford<{vec_type}>()" 2778 2779 if reduction_type in {"argmin", "argmax"}: 2780 cdtype = DTYPE_TO_CPP[scalar_type] 2781 acc_type = self.reduction_acc_type_vec(reduction_type, dtype) 2782 if reduction_type == "argmin": 2783 val = ( 2784 f"std::numeric_limits<{cdtype}>::infinity()" 2785 if is_float_dtype(dtype) 2786 else f"std::numeric_limits<{cdtype}>::max()" 2787 ) 2788 else: 2789 val = ( 2790 f"-std::numeric_limits<{cdtype}>::infinity()" 2791 if is_float_dtype(dtype) 2792 else f"std::numeric_limits<{cdtype}>::min()" 2793 ) 2794 return f"{acc_type}({val})" 2795 2796 if reduction_type == "any": 2797 return f"{self._get_mask_type()}::from(0)" 2798 2799 scalar_init = reduction_init(reduction_type, dtype) 2800 vec_init = f"{vec_type}({scalar_init})" 2801 if dtype == torch.bool: 2802 assert reduction_type in ("min", "max", "sum") 2803 return f"{self._get_mask_type()}::from({scalar_init})" 2804 return vec_init 2805 2806 def reduction_acc_type_vec(self, reduction_type, dtype): 2807 scalar_type = DTYPE_TO_COMPUTATION_DTYPE[dtype] 2808 vec_type = self._get_vec_type(scalar_type) 2809 if is_welford_reduction(reduction_type): 2810 return f"Welford<{vec_type}>" 2811 if reduction_type in {"argmin", "argmax"}: 2812 n_src = self._get_num_vectors(scalar_type) 2813 n_idx = self._get_num_vectors(torch.int64) 2814 return f"IndexValueVec<{DTYPE_TO_CPP[scalar_type]}, {n_src}, {n_idx}>" 2815 if dtype == torch.bool: 2816 assert reduction_type in ("min", "max", "any", "sum") 2817 return f"{self._get_mask_type()}" 2818 return vec_type 2819 2820 def welford_weight_reciprocal_vec(self, dtype, num_threads=None): 2821 vec_num_range_thread = ( 2822 CeilDiv(self.weight_recp_vec_range, num_threads) 2823 if num_threads 2824 else self.weight_recp_vec_range 2825 ) 2826 vec_num_range_thread_expr = cexpr_index(vec_num_range_thread) 2827 return ( 2828 f"static WeightRecp<{self._get_vec_type(dtype)}> {self.weight_recps_val}" 2829 f"(" 2830 f"{vec_num_range_thread_expr}" 2831 f");" 2832 ) 2833 2834 def reduction_combine_vec( 2835 self, 2836 reduction_type, 2837 var, 2838 next_value, 2839 use_weight_recps=False, 2840 index: Optional[sympy.Symbol] = None, 2841 horizontal_reduction: Optional[bool] = None, 2842 src_dtype: Optional[torch.dtype] = torch.float32, 2843 ): 2844 is_bool = src_dtype == torch.bool 2845 if reduction_type == "max": 2846 if self.tail_size: 2847 return f"max_masked_reduce({var}, {next_value}, {cexpr_index(self.tail_size)})" 2848 else: 2849 return ( 2850 f"{var} | {next_value}" 2851 if is_bool 2852 else f"at::vec::maximum({var}, {next_value})" 2853 ) 2854 elif reduction_type == "min": 2855 if self.tail_size: 2856 return f"min_masked_reduce({var}, {next_value}, {cexpr_index(self.tail_size)})" 2857 else: 2858 return ( 2859 f"{var} & {next_value}" 2860 if is_bool 2861 else f"at::vec::minimum({var}, {next_value})" 2862 ) 2863 elif reduction_type == "sum": 2864 if self.tail_size: 2865 return f"sum_masked_reduce({var}, {next_value}, {cexpr_index(self.tail_size)})" 2866 else: 2867 conjunction = "|" if is_bool else "+" 2868 return f"{var} {conjunction} {next_value}" 2869 elif reduction_type == "prod": 2870 if self.tail_size: 2871 return f"prod_masked_reduce({var}, {next_value}, {cexpr_index(self.tail_size)})" 2872 else: 2873 return f"{var} * {next_value}" 2874 elif reduction_type == "xor_sum": 2875 if self.tail_size: 2876 return f"xor_sum_masked_reduce({var}, {next_value}, 
{cexpr_index(self.tail_size)})" 2877 else: 2878 return f"{var} ^ {next_value}" 2879 elif reduction_type == "welford_reduce": 2880 if use_weight_recps: 2881 if self.tail_size: 2882 return f"welford_combine({var}, {next_value}, {cexpr_index(self.tail_size)}, &{self.weight_recps_val})" 2883 else: 2884 return f"welford_combine({var}, {next_value}, &{self.weight_recps_val})" 2885 else: 2886 if self.tail_size: 2887 return f"welford_combine({var}, {next_value}, {cexpr_index(self.tail_size)})" 2888 else: 2889 return f"welford_combine({var}, {next_value})" 2890 elif reduction_type == "welford_combine": 2891 if isinstance(next_value, tuple): 2892 # When reading a value from Inductor IR we have a tuple of variable names 2893 mean, m2, weight = next_value 2894 else: 2895 # When combining intermediate accumulators we have a Welford<T> struct 2896 mean, m2, weight = reduction_project(reduction_type, next_value) 2897 if self.tail_size: 2898 return f"welford_combine({var}, {{{mean}, {m2}, {weight}}}, {cexpr_index(self.tail_size)})" 2899 else: 2900 return f"welford_combine({var}, {{{mean}, {m2}, {weight}}})" 2901 elif reduction_type in ("argmin", "argmax"): 2902 assert src_dtype is not None 2903 cdtype = DTYPE_TO_CPP[src_dtype] 2904 n_src = self._get_num_vectors(src_dtype) 2905 n_idx = self._get_num_vectors(torch.int64) 2906 t_extra = "" 2907 arg_extra = "" 2908 if index is not None: 2909 assert horizontal_reduction is not None 2910 t_extra = f", {str(horizontal_reduction).lower()}" 2911 arg_extra = f", {index}" 2912 if self.tail_size: 2913 return ( 2914 f"{reduction_type}_combine_vec<{cdtype}, {n_src}, {n_idx}{t_extra}>" 2915 f"({var}, {next_value}{arg_extra}, {cexpr_index(self.tail_size)})" 2916 ) 2917 else: 2918 return f"{reduction_type}_combine_vec<{cdtype}, {n_src}, {n_idx}{t_extra}>({var}, {next_value}{arg_extra})" 2919 elif reduction_type == "any": 2920 return f"{var} | {next_value}" 2921 else: 2922 raise NotImplementedError 2923 2924 def indirect_assert(self, var, lower, upper, mask=None): 2925 assert isinstance(var, CppCSEVariable) 2926 assert var.dtype is not None 2927 if not var.is_vec: 2928 if isinstance(mask, CppCSEVariable) and mask.is_vec: 2929 mask = f"({mask}).all_masked()" 2930 return super().indirect_assert(var, lower, upper, mask) 2931 lower_scalar = lower 2932 upper_scalar = upper 2933 if lower: 2934 lower = f"{self._get_vec_type(var.dtype)}({lower})" 2935 if upper: 2936 upper = f"{self._get_vec_type(var.dtype)}({upper})" 2937 if lower and upper: 2938 cond = f"({lower} <= {var}) & ({var} < {upper})" 2939 cond_print = f"{lower_scalar} <= {var} < {upper_scalar}" 2940 elif lower: 2941 cond = f"{lower} <= {var}" 2942 cond_print = f"{lower_scalar} <= {var}" 2943 else: 2944 assert upper 2945 cond = f"{var} < {upper}" 2946 cond_print = f"{var} < {upper_scalar}" 2947 cond = f"{self._get_mask_type(var.dtype)}({cond})" 2948 if mask: 2949 if not mask.is_vec: 2950 mask = f"{self._get_mask_type(var.dtype)}({mask})" 2951 # We need not check when the mask is False 2952 cond = f"({cond}) | ~({mask})" 2953 if self.tail_size: 2954 cond = ( 2955 f"{self._get_mask_type(var.dtype)}::set({self._get_mask_type(var.dtype)}::from(1)" 2956 f", ({cond}), {cexpr_index(self.tail_size)})" 2957 ) 2958 cond = f"({cond}).all_masked()" 2959 return f'{self.assert_function}({cond}, "index out of bounds: {cond_print}")' 2960 2961 def get_to_dtype_expr(self, src, dtype, src_dtype): 2962 assert isinstance(src, CppCSEVariable) 2963 if not src.is_vec: 2964 return super().get_to_dtype_expr(src, dtype, src_dtype) 2965 
        src_cpp_type = DTYPE_TO_CPP[src_dtype]
        src_num_vectors = self._get_num_vectors(src_dtype)
        dst_cpp_type = DTYPE_TO_CPP[dtype]
        dst_num_vectors = self._get_num_vectors(dtype)
        expr = f"({src})"
        if src_dtype != torch.bool and dtype == torch.bool:
            expr = f"{self._get_mask_type(src_dtype)}::from<{src_cpp_type},{src_num_vectors}>({src})"
        elif src_dtype == torch.bool and dtype != torch.bool:
            expr = f"{src}.to<{dst_cpp_type},{dst_num_vectors}>()"
        elif src_dtype != dtype:
            if src_num_vectors == dst_num_vectors == 1:
                expr = f"at::vec::convert<{dst_cpp_type}>({src})"
            else:
                expr = f"at::vec::convert<{dst_cpp_type},{dst_num_vectors},{src_cpp_type},{src_num_vectors}>({src})"
        return expr


class CppTile2DKernel(CppVecKernel):
    """
    A vector kernel that handles 2D tiles with the tile size defined by `tiling_factor` on
    the inner-most loop level and on one of the outer loop levels (`outer_tiling_idx`). When the
    data tile is accessed contiguously along the outer loop axis, a transposition is applied to
    the tile to make the access contiguous along the inner-most loop axis. Then, the same
    vectorization logic from its parent `CppVecKernel` is leveraged for load/store/compute. The
    transposed tile load and store are generated into the kernel.preloads and kernel.poststores
    buffers.

    The loop structure looks like the following:
    for ...
      for i_outer ...
        for ...
          for inner_most ...
            // generated by CppTile2DKernel
            float tmp0[16*16]; at::vec::transpose_mxn<...>(tmp0, in_ptr0 + ..., ...); // into kernel.preloads
            float tmp1[16*16]; // into kernel.preloads
            for i_inner ... { // the kernel inner loop
              vectorized loads/compute/stores (e.g., load tmp0, store tmp1) // into kernel.loads/compute/stores
            }
            at::vec::transpose_mxn(out_ptr0 + ..., tmp1, ...) // into kernel.poststores
          for inner_most ... (tail)
            // generated by CppVecKernel
            ...
      for i_outer ... (tail)
        for ...
          for ...
            // generated by CppKernel
            ...
3011 """ 3012 3013 overrides = CppTile2DOverrides # type: ignore[assignment] 3014 3015 def __init__( 3016 self, 3017 args, 3018 num_threads, 3019 tiling_factor, 3020 tiling_indices, 3021 inner_tail_size=None, 3022 outer_tail_size=None, 3023 ): 3024 super().__init__( 3025 args, 3026 num_threads, 3027 tiling_factor, 3028 tiling_indices[1], 3029 inner_tail_size, 3030 ) 3031 self.tiling_indices = tiling_indices 3032 self.inner_tail_size = inner_tail_size 3033 self.outer_tail_size = outer_tail_size 3034 self.inner_num_elems = inner_tail_size if inner_tail_size else tiling_factor 3035 self.outer_num_elems = outer_tail_size if outer_tail_size else tiling_factor 3036 self.inner_is_tiling_idx = True 3037 3038 def inner_itervar(self): 3039 return sympy_index_symbol(f"{self.itervars[self.outer_idx]}_inner") 3040 3041 def need_vec_transpose(self, index): 3042 outer_var = self.itervars[self.outer_idx] 3043 inner_var = self.itervars[self.tiling_idx] 3044 outer_stride = stride_at_vec_range(index, outer_var, self.tiling_factor) 3045 inner_stride = stride_at_vec_range(index, inner_var, self.tiling_factor) 3046 return ( 3047 self._load_mask is None # TODO: support transposition with mask 3048 and outer_stride == 1 3049 and index.has(inner_var) 3050 and not inner_stride.has(inner_var) 3051 and not inner_stride.has(outer_var) 3052 ) 3053 3054 def gen_transposed_tile_load_store(self, name, var, index, is_store): 3055 # transposed tile load/store outside the kernel inner loop 3056 dtype = V.graph.get_dtype(name) 3057 factor = self.tiling_factor 3058 src = f"{var} + {cexpr_index(index)}" 3059 dst = "__place_holder__" 3060 ld_src = f"{cexpr_index(stride_at_vec_range(index, self.itervars[self.tiling_idx], self.tiling_factor))}" 3061 ld_dst = f"{cexpr_index(self.num_elems)}" 3062 if is_store: 3063 src, dst = dst, src 3064 ld_src, ld_dst = ld_dst, ld_src 3065 3066 need_define = True 3067 if self.inner_is_tiling_idx ^ is_store: 3068 M, N = self.inner_num_elems, self.outer_num_elems 3069 else: 3070 M, N = ( 3071 self.outer_num_elems, 3072 self.inner_num_elems, 3073 ) 3074 if (isinstance(M, sympy.Expr) and not M.is_number) or ( 3075 isinstance(N, sympy.Expr) and not N.is_number 3076 ): 3077 load_or_store = ( 3078 f"at::vec::transpose_mxn<{DTYPE_TO_CPP[dtype]}>" 3079 f"({src}, {ld_src}, {dst}, {ld_dst}, {cexpr_index(M)}, {cexpr_index(N)});" 3080 ) 3081 else: 3082 load_or_store = ( 3083 f"at::vec::transpose_mxn<{DTYPE_TO_CPP[dtype]},{cexpr_index(M)},{cexpr_index(N)}>" 3084 f"({src}, {ld_src}, {dst}, {ld_dst});" 3085 ) 3086 if is_store: 3087 tile_var = self.cse.newvar() 3088 elif load_or_store not in self.cse.cache: 3089 tile_var = self.cse.generate(self.preloads, load_or_store, write=False) 3090 else: 3091 need_define = False 3092 tile_var = self.cse.cache[load_or_store] 3093 3094 if need_define: 3095 define_line = f"alignas({factor}) {DTYPE_TO_CPP[dtype]} {tile_var}[{factor}*{factor}];" 3096 self.preloads.writeline(define_line) 3097 3098 load_or_store = load_or_store.replace("__place_holder__", str(tile_var)) 3099 if is_store: 3100 self.poststores.writeline(DeferredLine(name, load_or_store)) 3101 else: 3102 self.preloads.writeline(load_or_store) 3103 3104 return tile_var 3105 3106 def load(self, name: str, index: sympy.Expr): 3107 var = self.args.input(name) 3108 index = self.rename_indexing(index) 3109 3110 inner = self.inner_itervar() 3111 if self.need_vec_transpose(index): 3112 tile_var = self.gen_transposed_tile_load_store( 3113 name, var, index, is_store=False 3114 ) 3115 # vector load inside the kernel inner loop 
3116 loadbuf = f"{tile_var} + {cexpr_index(inner * self.num_elems)}" 3117 dtype = V.graph.get_dtype(name) 3118 line = self._get_vec_load_line(loadbuf, 0, dtype) # type: ignore[arg-type] 3119 csevar = self.cse.generate(self.loads, line) 3120 csevar.update_on_args("load", (self, name, index), {}) 3121 assert isinstance(csevar, CppCSEVariable) 3122 csevar.is_vec = True 3123 return csevar 3124 else: 3125 new_index = self.transform_indexing(index) 3126 return super().load(name, new_index) 3127 3128 def store(self, name, index, value, mode=None): 3129 assert "buf" in name 3130 var = self.args.output(name) 3131 3132 inner = self.inner_itervar() 3133 index = self.rename_indexing(index) 3134 assert mode is None 3135 if self.need_vec_transpose(index): 3136 tile_var = self.gen_transposed_tile_load_store( 3137 name, var, index, is_store=True 3138 ) 3139 # vector store inside the kernel inner loop 3140 storebuf = f"{tile_var} + {cexpr_index(inner * self.num_elems)}" 3141 if self.tail_size or V.graph.get_dtype(name) in DTYPE_LOWP_FP + [ 3142 torch.uint8, 3143 torch.int8, 3144 ]: 3145 line = f"{value}.store({storebuf}, {cexpr_index(self.num_elems)});" 3146 else: 3147 line = f"{value}.store({storebuf});" 3148 self.stores.writeline(DeferredLine(name, line)) 3149 else: 3150 new_index = self.transform_indexing(index) 3151 super().store(name, new_index, value, mode) 3152 3153 def codegen_inner_loops(self, code): 3154 inner = self.inner_itervar() 3155 if self.inner_is_tiling_idx: 3156 code.writeline( 3157 f"for (long {inner} = 0; {inner} < {cexpr_index(self.outer_num_elems)}; {inner}++)" 3158 ) 3159 else: 3160 code.writeline( 3161 f"for (long {inner} = 0; {inner} < {cexpr_index(self.inner_num_elems)}; {inner}++)" 3162 ) 3163 3164 def set_ranges(self, group, reduction_group): 3165 vars = super().set_ranges(group, reduction_group) 3166 # do vertical reduction as the tail loop 3167 self.outer_idx, self.tiling_idx = ( 3168 self.tiling_indices 3169 if self.tiling_indices[1] < self.reduction_depth 3170 else reversed(self.tiling_indices) 3171 ) 3172 if self.tiling_idx == self.tiling_indices[0]: 3173 self.tail_size = self.outer_tail_size 3174 self.num_elems = self.outer_num_elems 3175 self.inner_is_tiling_idx = False 3176 else: 3177 self.tail_size = self.inner_tail_size 3178 self.num_elems = self.inner_num_elems 3179 self.inner_is_tiling_idx = True 3180 return vars 3181 3182 def transform_indexing(self, index: sympy.Expr) -> sympy.Expr: 3183 return self.scale_index_with_offset( 3184 index, 3185 itervar_idx=self.outer_idx, 3186 offset=self.inner_itervar(), 3187 ) 3188 3189 3190def get_loop_body_lowp_fp(_body: LoopBody) -> Tuple[Optional[torch.dtype], bool]: 3191 """ 3192 Returns the low precision data type (torch.float16/torch.bfloat16) contained in the nodes 3193 and if all the nodes can codegen with this data type without converting to float. 3194 Otherwise returns None and True. 
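    Note: the second element of the returned tuple is `_use_fp32`, i.e. it is True when
    at least one op must be computed in float32; callers typically check
    `result[0] is not None and not result[1]`.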
3195 """ 3196 sub_blocks = [_body.root_block] + list(_body.subblocks.values()) 3197 3198 _lowp_fp_type: Optional[torch.dtype] = None 3199 _use_fp32 = False 3200 for sub_block in sub_blocks: 3201 for _node in sub_block.graph.nodes: 3202 if _node.op == "placeholder" or _node.target in ( 3203 "get_index", 3204 "index_expr", 3205 ): 3206 continue 3207 3208 # Fast path if all operations can support bf16/fp16 without converting to fp32 3209 if _node.target not in [ 3210 "load", 3211 "store", 3212 "abs", 3213 "neg", 3214 "output", 3215 ]: 3216 _use_fp32 = True 3217 3218 if hasattr(_node, "meta") and _node.meta: 3219 assert OptimizationContext.key in _node.meta 3220 opt_ctx: OptimizationContext = _node.meta[OptimizationContext.key] 3221 if not opt_ctx.dtype or opt_ctx.dtype not in DTYPE_LOWP_FP: 3222 _use_fp32 = True 3223 elif _lowp_fp_type is not None: 3224 if _lowp_fp_type != opt_ctx.dtype: 3225 warnings.warn("bf16 and fp16 are mixed in the scheduler node.") 3226 else: 3227 _lowp_fp_type = opt_ctx.dtype 3228 else: 3229 _use_fp32 = True 3230 3231 return _lowp_fp_type, _use_fp32 3232 3233 3234class TilingSelect: 3235 """ 3236 Implement the heuristic to select the tiling factors and tiling indices. 3237 In the future, we can implement advanced heuristic in a subclass. 3238 """ 3239 3240 def __init__(self): 3241 super().__init__() 3242 3243 def select_tiling( 3244 self, 3245 fn_list, 3246 var_sizes_list, 3247 ) -> Tuple[List[int], List[int]]: 3248 # TODO(jgong5): support alternative tiling factors and data types 3249 loop_bodies = _get_loop_body(fn_list) 3250 all_dtypes = _get_dtype_from_loopbodies(loop_bodies) 3251 assert all_dtypes 3252 if any(dtype not in VECTORIZABLE_DTYPES for dtype in all_dtypes): 3253 return [], [] 3254 dtype = torch.float 3255 _lowp_fp_dtype = get_loop_body_lowp_fp(loop_bodies[0])[0] 3256 if _lowp_fp_dtype and all( 3257 (get_loop_body_lowp_fp(loop_body)[0] == _lowp_fp_dtype) 3258 for loop_body in loop_bodies[1:] 3259 ): 3260 dtype = _lowp_fp_dtype 3261 3262 tiling_factor = cpu_vec_isa.pick_vec_isa().nelements(dtype=dtype) 3263 tiling_indices = self._select_tiling_indices( 3264 fn_list, var_sizes_list, tiling_factor 3265 ) 3266 3267 if tiling_indices: 3268 group, reduction_group = max( 3269 var_sizes_list, key=lambda sizes: len(sizes[1]) 3270 ) 3271 call_ranges = tuple(group) + tuple(reduction_group) 3272 3273 if config.cpp.enable_tiling_heuristics: 3274 3275 def _try_get_stride( 3276 index, 3277 itervars, 3278 tiling_factor, 3279 tiling_indices, 3280 ): 3281 itervar = itervars[tiling_indices[0]] 3282 stride = stride_at_vec_range(index, itervar, tiling_factor) 3283 return stride if stride.is_number else None 3284 3285 def _update_negative_op_count( 3286 node_name, non_contig_indexing_op_counter 3287 ): 3288 if node_name not in non_contig_indexing_op_counter: 3289 non_contig_indexing_op_counter[node_name] = 1 3290 else: 3291 non_contig_indexing_op_counter[node_name] += 1 3292 3293 def _is_valid_indices( 3294 itervars, 3295 tiling_indices, 3296 ): 3297 return ( 3298 len(tiling_indices) == 1 3299 and len(itervars) > 0 3300 and ( 3301 tiling_indices[0] 3302 if tiling_indices[0] >= 0 3303 else tiling_indices[0] + len(itervars) 3304 ) 3305 < len(itervars) 3306 ) 3307 3308 itervars = [ 3309 sympy_index_symbol_with_prefix(SymT.XBLOCK, n) 3310 for n in range(len(call_ranges)) 3311 ] 3312 reduction_depth = len(group) 3313 vars, reduction_vars = ( 3314 itervars[:reduction_depth], 3315 itervars[reduction_depth:], 3316 ) 3317 op_counter: Dict[str, int] = {} 3318 # ops may cause overhead 
with vectorization, like non-contiguous 3319 # index_expr, load, store 3320 non_contig_indexing_op_counter: Dict[str, int] = {} 3321 for _body in loop_bodies: 3322 sub_blocks = [_body.root_block] + list(_body.subblocks.values()) 3323 for sub_block in sub_blocks: 3324 for _node in sub_block.graph.nodes: 3325 if _node.target in ["index_expr", "load", "store"]: 3326 # get the index and replace prefix from z to x 3327 arg_idx = 1 if _node.target == "index_expr" else 2 3328 index = sub_block.body.indexing_from_args( 3329 (vars, reduction_vars) 3330 )[_node.args[arg_idx].args[0]] 3331 if _is_valid_indices(itervars, tiling_indices): 3332 stride = _try_get_stride( 3333 index, itervars, tiling_factor, tiling_indices 3334 ) 3335 if ( 3336 stride is None 3337 if _node.target == "index_expr" 3338 else stride not in [0, 1] 3339 ): 3340 _update_negative_op_count( 3341 _node.target, non_contig_indexing_op_counter 3342 ) 3343 if isinstance(_node.target, str) and not ( 3344 _node.target.startswith("masked_subblock") 3345 or _node.target 3346 in ["ops", "output", "constant", "get_index"] 3347 ): 3348 if _node.target not in op_counter: 3349 op_counter[_node.target] = 1 3350 else: 3351 op_counter[_node.target] += 1 3352 3353 op_num = sum(op_counter.values()) 3354 non_contig_indexing_op_num = sum( 3355 non_contig_indexing_op_counter.values() 3356 ) 3357 threshold = 0.08 3358 if op_num > 0 and non_contig_indexing_op_num / op_num >= threshold: 3359 # Too many non-contiguous load/store/index_expr which hurts the 3360 # vectorization performance. Disable vectorization when exceeding 3361 # the threshold. 3362 return [], [] 3363 3364 if ( 3365 not reduction_group 3366 and group 3367 and len(tiling_indices) == 1 3368 and not has_free_symbols( 3369 [ 3370 group[tiling_indices[0]], 3371 ] 3372 ) 3373 and group[tiling_indices[0]] < tiling_factor / 2 3374 ): 3375 # For case of Multi Thread AMP Static shape of pyhpc_isoneutral_mixing, 3376 # the inner loop range doesn't have enough elements to do vectorization 3377 # explicitly and found that `#pragma GCC ivdep` has better performance than 3378 # `#pragma omp simd simdlen(8)`. Disable vectorization for this case. 3379 # <TODO> Leslie: maybe we can always disable vectorization when loop range is less 3380 # than tiling factor and enable `#pragma omp simd simdlen(8)` for scalar kernel 3381 # when needed. 
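                    # (Illustrative: with an AVX2 float kernel, tiling_factor is 8, so an
                    # inner loop of length 3 hits this branch and falls back to the scalar kernel.)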
3382 return [], [] 3383 3384 if dtype in DTYPE_LOWP_FP: 3385 # For lower precision data type, if the call_range is not long enough, 3386 # use tiling_factor // 2 for better performance 3387 factor_lowp = cpu_vec_isa.pick_vec_isa().nelements(dtype=dtype) 3388 for tiling_indice in tiling_indices: 3389 if tiling_indice < 0: 3390 tiling_indice = tiling_indice + len(call_ranges) 3391 if tiling_indice < 0 or tiling_indice >= len(call_ranges): 3392 continue 3393 if has_free_symbols(call_ranges): 3394 call_range = V.graph.sizevars.size_hint( 3395 call_ranges[tiling_indice], fallback=0 3396 ) 3397 if call_range < factor_lowp: 3398 V.graph.sizevars.guard_lt(call_range, factor_lowp) 3399 tiling_factor = factor_lowp // 2 3400 break 3401 elif call_ranges[tiling_indice] < factor_lowp: 3402 tiling_factor = factor_lowp // 2 3403 break 3404 3405 if len(tiling_indices) == 1: 3406 return [tiling_factor], tiling_indices 3407 if len(tiling_indices) == 2: 3408 return [tiling_factor, tiling_factor], tiling_indices 3409 return [], [] 3410 3411 def _select_tiling_indices( 3412 self, 3413 fn_list, 3414 var_sizes_list, 3415 tiling_factor, 3416 ): 3417 all_index = [] 3418 for fn, var_sizes in zip(fn_list, var_sizes_list): 3419 rw = dependencies.extract_read_writes(fn, *var_sizes) 3420 all_index += [dep.index for dep in itertools.chain(rw.reads, rw.writes)] 3421 contig_vars = set() 3422 contig_vars_list = [] 3423 non_contig_stride_const = set() 3424 non_contig_stride_other = set() 3425 for index in all_index: 3426 for var in index.free_symbols: 3427 if not re.search(r"^d\d+$", var.name): 3428 continue 3429 stride = stride_at_vec_range(index, var, tiling_factor) 3430 if stride == 0: 3431 continue 3432 elif stride == 1: 3433 contig_vars.add(int(var.name[1:])) 3434 contig_vars_list.append(int(var.name[1:])) 3435 elif all(symbol_is_type(s, SymT.SIZE) for s in stride.free_symbols): 3436 non_contig_stride_const.add(int(var.name[1:])) 3437 else: 3438 non_contig_stride_other.add(int(var.name[1:])) 3439 contig_only = contig_vars - non_contig_stride_const - non_contig_stride_other 3440 group, reduction_group = max(var_sizes_list, key=lambda sizes: len(sizes[1])) 3441 num_itervars = len(group) + len(reduction_group) 3442 if len(contig_vars) == 0: 3443 # no contiguous vars 3444 return [num_itervars - 1] 3445 if contig_only: 3446 return sorted(contig_only)[-1:] 3447 contig_and_const_stride = ( 3448 contig_vars & non_contig_stride_const 3449 ) - non_contig_stride_other 3450 contig_vars_sorted = sorted(contig_vars) 3451 if ( 3452 len(contig_vars_sorted) == 2 3453 and contig_vars_sorted[-1] in contig_and_const_stride 3454 and contig_vars_sorted[-1] == num_itervars - 1 3455 ): 3456 return contig_vars_sorted 3457 return sorted(contig_vars_sorted, key=contig_vars_list.count)[-1:] 3458 3459 3460class CppKernelProxy(CppKernel): 3461 def __init__(self, kernel_group): 3462 super().__init__(kernel_group.args, kernel_group.ws.num_threads) 3463 self.kernel_group = kernel_group 3464 self.loop_nest = None 3465 self.call_ranges = None 3466 self.picked_vec_isa: cpu_vec_isa.VecISA = cpu_vec_isa.pick_vec_isa() 3467 3468 def data_type_propagation(self, nodes): 3469 for _node in nodes: 3470 assert isinstance(_node, SchedulerNode) 3471 DataTypePropagation.propagate_scheduler_node(_node) 3472 3473 # Check if all the nodes of a given fx graph can support BF16/FP16 3474 def is_lowp_fp_scheduler(self, scheduler_node: SchedulerNode): 3475 if not isinstance(scheduler_node._body, LoopBody): 3476 return True 3477 # Propagate the dtype to check if all the fx 
node is bf16/fp16 3478 DataTypePropagation.propagate_scheduler_node(scheduler_node) 3479 return ( 3480 get_loop_body_lowp_fp(scheduler_node._body)[0] is not None 3481 and not get_loop_body_lowp_fp(scheduler_node._body)[1] 3482 ) 3483 3484 def legalize_lowp_fp_dtype_loopbody(self, loop_body: LoopBody): 3485 def add_to_dtype(sub_graph: torch.fx.Graph): 3486 def is_lowp_fp_load(node: torch.fx.Node): 3487 if node.target not in ["load"]: 3488 return False 3489 assert len(node.args) == 3 3490 load_dtype = V.graph.get_dtype(node.args[1]) # type: ignore[arg-type] 3491 return load_dtype in DTYPE_LOWP_FP 3492 3493 def is_lowp_fp_store(node: torch.fx.Node): 3494 if node.target != "store": 3495 return False 3496 _, store_var, _, _, _ = node.args 3497 store_dtype = V.graph.get_dtype(store_var) # type: ignore[arg-type] 3498 return store_dtype in DTYPE_LOWP_FP 3499 3500 sub_graph_nodes = list(sub_graph.nodes) 3501 to_lowp_fp_legalized_nodes = [] 3502 for _node in sub_graph_nodes: 3503 if is_lowp_fp_load(_node): 3504 # No need to promote to float if all users are direct stores 3505 if all(user.target == "store" for user in _node.users): 3506 continue 3507 ops = _node.args[0] 3508 with sub_graph.inserting_after(_node): 3509 to_type_node = sub_graph.call_method( 3510 "to_dtype", args=(ops, _node, torch.float) 3511 ) 3512 to_type_node_args = to_type_node.args 3513 _node.replace_all_uses_with(to_type_node) 3514 to_type_node.args = to_type_node_args 3515 metrics.cpp_to_dtype_count += 1 3516 elif is_lowp_fp_store(_node): 3517 ops, name, _, value_var, _ = _node.args 3518 # No need to promote to float if it is a user of a load which are all directly stored 3519 if value_var.target == "load" and all( 3520 user.target == "store" for user in value_var.users 3521 ): 3522 continue 3523 dtype = V.graph.get_dtype(name) 3524 with sub_graph.inserting_before(_node): 3525 to_type_node = sub_graph.call_method( 3526 "to_dtype", args=(ops, value_var, dtype) 3527 ) 3528 _node.replace_input_with(value_var, to_type_node) 3529 metrics.cpp_to_dtype_count += 1 3530 elif _node.target == "reduction": 3531 ( 3532 ops, 3533 dtype, 3534 src_dtype, 3535 reduction_type, 3536 value, 3537 ) = _node.args 3538 if src_dtype in DTYPE_LOWP_FP: 3539 # Since we always convert the load/store value to float if the tensor is bfloat16/float16. 3540 # Therefore, the reduction should never work with bfloat16/float16 value. Hence, we update 3541 # the bfloat16/float16 reduction by 3542 # 1) updating the src_dtype to float 3543 # and 2) updating the dtype to float if it is bfloat16/float16. 3544 assert dtype in [ 3545 torch.float, 3546 torch.bfloat16, 3547 torch.float16, 3548 torch.int64, 3549 ] 3550 _node.args = ( 3551 ops, 3552 torch.float if dtype in DTYPE_LOWP_FP else dtype, 3553 torch.float, 3554 reduction_type, 3555 value, 3556 ) 3557 elif _node.target == "to_dtype" and _node.args[-1] in DTYPE_LOWP_FP: 3558 (ops, x, _) = _node.args 3559 # The legalization always loads the BF16/FP16 tensor as FP32 for computation 3560 # and converts back to BF16/FP16 after the computation. 3561 # Hence, there should be no computation w/ BF16/FP16. 3562 # Therefore, we update the to_dtype by replacing the bf16/fp16 dtype with fp32. 
                    # Save the legalized to_dtype node for the elimination (eliminate_to_dtype step):
                    # 1) Eliminate the redundant to_dtype node if we have a pattern as follows:
                    #     graph():
                    #       %lowp_fp_legalized = call_method[target=to_dtype](args = (%ops, %input, torch.float))
                    #       %to_dtype2 = call_method[target=to_dtype](args = (%ops, %lowp_fp_legalized, torch.bfloat16/float16))
                    # The first to_dtype is redundant because
                    # the second to_dtype also converts to torch.bfloat16/torch.float16.
                    # Hence, we remove the first to_dtype.
                    to_lowp_fp_legalized_nodes.append(_node)
                    _node.args = (ops, x, torch.float)
                else:
                    pass

            def eliminate_to_dtype(sub_graph: torch.fx.Graph):
                def _eliminate_duplicate_to_node(sub_graph: torch.fx.Graph):
                    # Eliminate the redundant to_dtype node. Consider the following pattern:
                    #     graph():
                    #       %to_dtype1 = call_method[target=to_dtype](args = (%ops, %input, torch.float), kwargs = {})
                    #       %to_dtype2 = call_method[target=to_dtype](args = (%ops, %to_dtype1, torch.float), kwargs = {})
                    # The first to_dtype is redundant because the second to_dtype also converts
                    # to torch.float. Hence, we remove the first to_dtype.
                    def _used_by_to(to_node: torch.fx.Node):
                        return all(usr.target == "to_dtype" for usr in to_node.users)

                    all_to_nodes = [
                        node for node in sub_graph.nodes if node.target == "to_dtype"
                    ]
                    all_to_nodes_and_users = [
                        {node: node.users} for node in all_to_nodes if _used_by_to(node)
                    ]
                    for node_users in all_to_nodes_and_users:
                        for node, users in node_users.items():
                            if node in sub_graph.nodes and (
                                all(usr.args[-1] == node.args[-1] for usr in users)
                                or (
                                    node in to_lowp_fp_legalized_nodes
                                    and all(
                                        usr.args[-1] in DTYPE_LOWP_FP for usr in users
                                    )
                                )
                            ):
                                val_node = node.all_input_nodes[-1]
                                node.replace_all_uses_with(val_node)
                                sub_graph.erase_node(node)

                    # In debug mode, the graph of a LoopBody attaches a new GraphModule as its
                    # owning_module, while release mode does not. The lint checks whether the
                    # graph has an owning_module to decide whether call_module nodes need to be
                    # validated. A LoopBody may contain get_index as a module call even though
                    # it is just a function, so it cannot pass the lint check in debug mode.
                    # We therefore only run the lint when owning_module is None. Eventually,
                    # get_index should be emitted via call_function rather than call_module.
3615 if sub_graph.owning_module is None: 3616 sub_graph.lint() 3617 3618 _eliminate_duplicate_to_node(sub_graph) 3619 3620 eliminate_to_dtype(sub_graph) 3621 3622 sub_blocks = [loop_body.root_block] + list(loop_body.subblocks.values()) 3623 for sub_block in sub_blocks: 3624 add_to_dtype(sub_block.graph) 3625 3626 def legalize_lowp_fp_dtype(self, nodes): 3627 if all( 3628 isinstance(_node, SchedulerNode) and self.is_lowp_fp_scheduler(_node) 3629 for _node in nodes 3630 ): 3631 # Mark the load node to load bf16/fp16 3632 for _node in nodes: 3633 sub_blocks = [_node._body.root_block] + list( 3634 _node._body.subblocks.values() 3635 ) 3636 for sub_block in sub_blocks: 3637 for fx_node in sub_block.graph.nodes: 3638 if fx_node.target in ["load", "store"]: 3639 assert fx_node.meta 3640 assert OptimizationContext.key in fx_node.meta 3641 opt_ctx: OptimizationContext = fx_node.meta[ 3642 OptimizationContext.key 3643 ] 3644 assert opt_ctx.dtype in DTYPE_LOWP_FP 3645 3646 # Bypass the legalization as the kernel can run with bf16/fp16 directly 3647 return 3648 3649 for _node in nodes: 3650 assert isinstance(_node, SchedulerNode) 3651 assert isinstance(_node._body, LoopBody) 3652 body: LoopBody = _node._body 3653 if not body.is_memory_copy(): 3654 self.legalize_lowp_fp_dtype_loopbody(body) 3655 3656 def codegen_functions(self, fn_list, var_sizes_list): 3657 assert len(fn_list) == len(var_sizes_list) 3658 kernel_group = self.kernel_group 3659 group, reduction_group = max(var_sizes_list, key=lambda sizes: len(sizes[1])) 3660 3661 self.set_ranges(group, reduction_group) 3662 3663 def codegen_kernel(cls, *args): 3664 with kernel_group.new_kernel(cls, *args) as kernel: 3665 # Ugly hack to maintain the metrics kernel count since 3666 # we only count in CppKernelProxy, not those contained in it 3667 metrics.generated_kernel_count -= 1 3668 3669 run(kernel) 3670 return kernel 3671 3672 def run(kernel): 3673 vars, reduction_vars = kernel.set_ranges(group, reduction_group) 3674 in_suffix = False 3675 for fn, var_sizes in zip(fn_list, var_sizes_list): 3676 if var_sizes in [ 3677 (group, reduction_group), 3678 (tuple(itertools.chain(group, reduction_group)), ()), 3679 ]: 3680 assert not in_suffix 3681 fn(vars, reduction_vars) 3682 else: 3683 in_suffix = True 3684 assert var_sizes == ( 3685 group, 3686 (), 3687 ), f"unexpected group: {var_sizes} != {group}, {reduction_group}" 3688 # we can fuse in some extra pointwise into the suffix 3689 with kernel.write_to_suffix(): 3690 fn(vars, ()) 3691 3692 scalar_kernel = codegen_kernel(CppKernel) 3693 V.graph.removed_buffers |= scalar_kernel.removed_buffers 3694 V.graph.inplaced_to_remove |= scalar_kernel.inplaced_to_remove 3695 self.loop_nest = LoopNestWithSplit.build(scalar_kernel) 3696 3697 if not self.picked_vec_isa: 3698 return 3699 3700 if not self.itervars: 3701 # not a loop 3702 return 3703 3704 # Kernels share the same global contexts like V.graph.wrapper_code, V.kernel.args. 3705 # But the generated scalar kernel has updated these global contexts. Hence, the other kernels 3706 # should not do this again to avoid context conflict. By now, we only control the 3707 # config.inplace_buffers. In the future, we could maintain more contexts. 
        with torch._inductor.config.patch(inplace_buffers=False):
            tiling_select = TilingSelect()
            tiling_factors, tiling_indices = tiling_select.select_tiling(
                fn_list, var_sizes_list
            )
            assert len(tiling_factors) == len(tiling_indices)
            # <TODO> This should be removed once vectorization is fully supported.
            could_masked_vec = True
            all_dtypes = _get_dtype_from_loopbodies(_get_loop_body(fn_list))
            if any(dtype not in MASKED_VECTORIZABLE_DTYPES for dtype in all_dtypes):
                # can be removed once the masked vectorizable dtypes match the vectorizable dtypes
                could_masked_vec = False

            if len(tiling_indices) == 1:
                vec_kernel = codegen_kernel(
                    CppVecKernel, tiling_factors[0], tiling_indices[0]
                )
                metrics.generated_cpp_vec_kernel_count += 1
                main_loop, tail_loop = self.loop_nest.split_with_tiling(
                    tiling_indices[0], factor=tiling_factors[0]
                )
                main_loop.set_kernel(vec_kernel)
                main_loop.simd_vec = True
                if config.cpp.enable_loop_tail_vec and could_masked_vec:
                    tail_loop.steps = tail_loop.size - tail_loop.offset
                    masked_vec_kernel = codegen_kernel(
                        CppVecKernel,
                        tiling_factors[0],
                        tiling_indices[0],
                        tail_loop.steps,
                    )
                    tail_loop.set_kernel(masked_vec_kernel)
                    tail_loop.simd_vec = True
                else:
                    tail_loop.set_kernel(scalar_kernel)
                    tail_loop.simd_omp = True
                    # We split the loop by nelements into two parts: a main loop and a tail loop.
                    # The main loop is straightforwardly vectorized with nelements. The tail loop
                    # can still be vectorized, just with a smaller width: for example, if
                    # nelements is 8 (256 bits), the tail loop can still be vectorized with
                    # 4 lanes (128 bits).
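                    # Illustrative sketch, assuming AVX2 and float32 so that nelements == 8:
                    # for a loop of length N, the main loop above covers (N / 8) * 8 iterations
                    # with the vector kernel, and this scalar tail loop gets
                    # simd_nelements == 4, which LoopLevel.lines() emits as an
                    # "omp simd simdlen(4)" hint for the remaining (at most 7) iterations.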
3749 tail_loop.simd_nelements = tiling_factors[0] // 2 3750 elif len(tiling_indices) == 2: 3751 assert ( 3752 tiling_indices[1] == len(self.itervars) - 1 3753 and tiling_factors[0] == tiling_factors[1] 3754 ) 3755 3756 metrics.generated_cpp_vec_kernel_count += 2 3757 outer_main_loop, outer_tail_loop = self.loop_nest.split_with_tiling( 3758 tiling_indices[0], factor=tiling_factors[0] 3759 ) 3760 ( 3761 inner_main_loop, 3762 inner_tail_loop, 3763 ) = outer_main_loop.split_with_tiling( 3764 tiling_indices[1] - tiling_indices[0], factor=tiling_factors[0] 3765 ) 3766 tile2d_kernel = codegen_kernel( 3767 CppTile2DKernel, tiling_factors[0], tiling_indices 3768 ) 3769 inner_main_loop.set_kernel(tile2d_kernel) 3770 3771 if config.cpp.enable_loop_tail_vec and could_masked_vec: 3772 ( 3773 inner_main_loop_of_outer_tail_loop, 3774 inner_tail_loop_of_outer_tail_loop, 3775 ) = outer_tail_loop.split_with_tiling( 3776 tiling_indices[1] - tiling_indices[0], factor=tiling_factors[0] 3777 ) 3778 3779 for tail_loop in ( 3780 inner_tail_loop, 3781 outer_tail_loop, 3782 inner_tail_loop_of_outer_tail_loop, 3783 ): 3784 tail_loop.steps = tail_loop.size - tail_loop.offset 3785 3786 for tail_loop, inner_tail_size, outer_tail_size in ( 3787 (inner_tail_loop, inner_tail_loop.steps, None), 3788 ( 3789 inner_main_loop_of_outer_tail_loop, 3790 None, 3791 outer_tail_loop.steps, 3792 ), 3793 ( 3794 inner_tail_loop_of_outer_tail_loop, 3795 inner_tail_loop_of_outer_tail_loop.steps, 3796 outer_tail_loop.steps, 3797 ), 3798 ): 3799 masked_tile2d_kernel = codegen_kernel( 3800 CppTile2DKernel, 3801 tiling_factors[0], 3802 tiling_indices, 3803 inner_tail_size, 3804 outer_tail_size, 3805 ) 3806 tail_loop.set_kernel(masked_tile2d_kernel) 3807 else: 3808 vec_kernel = codegen_kernel( 3809 CppVecKernel, tiling_factors[0], tiling_indices[0] 3810 ) 3811 inner_tail_loop.set_kernel(vec_kernel) 3812 3813 outer_tail_loop.set_kernel(scalar_kernel) 3814 3815 def codegen_loop_bodies(self, loop_bodies, var_sizes_list): 3816 for body in loop_bodies: 3817 self.legalize_lowp_fp_dtype_loopbody(body) 3818 DataTypePropagation.propagate_loopbody(body) 3819 self.codegen_functions(loop_bodies, var_sizes_list) 3820 3821 def codegen_nodes(self, nodes: List[SchedulerNode]): 3822 # Legalize BF16 node by adding to_dtype explicitly 3823 self.legalize_lowp_fp_dtype(nodes) 3824 self.data_type_propagation(nodes) 3825 assert len(nodes) >= 1 3826 3827 def fn(node, *index_vars): 3828 node.decide_inplace_update() 3829 node.mark_run() 3830 if isinstance(V.kernel, NullKernelHandler): 3831 return node._body(*index_vars) 3832 else: 3833 return node.codegen(index_vars) 3834 3835 fn_list = [functools.partial(fn, node) for node in nodes] 3836 3837 if ( 3838 isinstance(V.local_buffer_context, LocalBufferContext) 3839 and V.local_buffer_context.local_buffers 3840 ): 3841 3842 def wrap_fn(fn): 3843 wrapped_fn = V.local_buffer_context.localize_function( 3844 fn, 3845 ) 3846 wrapped_fn.original_fn = fn 3847 return wrapped_fn 3848 3849 fn_list = [wrap_fn(fn) for fn in fn_list] 3850 3851 var_sizes_list = [node.group[1] for node in nodes] 3852 self.codegen_functions(fn_list, var_sizes_list) 3853 3854 def codegen_loops(self, code, worksharing): 3855 self.codegen_loops_impl(self.loop_nest, code, worksharing) 3856 3857 3858class OuterLoopFusedKernel(CppKernel): 3859 def __init__(self, kernel_group): 3860 super().__init__(kernel_group.args, kernel_group.ws.num_threads) 3861 self.inner: List[LoopLevel] = [] 3862 3863 def decide_parallel_depth(self, max_parallel_depth, threads) -> int: 
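        """
        Decide how many outer loop levels to parallelize for the outer-loop-fused kernel:
        for each fused inner loop, ask its kernels (which all share the same call_ranges)
        how deep they could parallelize, then cap the largest of those depths by
        `max_parallel_depth`.
        """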
3864 kernels_parallel_depth = [] 3865 nested_kernels: List[List[CppKernel]] = [ 3866 loop.get_kernels() for loop in self.inner 3867 ] 3868 for kernels in nested_kernels: 3869 # For any ScalarKernel, VecKernel, or Tile2DKernel, 3870 # they should all have the same call_ranges 3871 call_ranges = kernels[0].call_ranges 3872 assert call_ranges is not None 3873 assert all(kernel.call_ranges == call_ranges for kernel in kernels) 3874 kernels_parallel_depth.append( 3875 kernels[0].decide_parallel_depth(len(call_ranges), threads) 3876 ) 3877 return min( 3878 max_parallel_depth, 3879 max(kernels_parallel_depth), 3880 ) 3881 3882 3883class ReasonFusedNodes(Enum): 3884 SAME_VARS_REDUCE = "same_vars_reduce" 3885 COMPATIBLE_REDUCTION = "compatible_reduction" 3886 COMPATIBLE_RANGES_NO_REDUCTION = "compatible_ranges_no_reduction" 3887 3888 3889class CppScheduling(BaseScheduling): 3890 # ctypes limits the number of args to 1024, refer to: 3891 # https://github.com/python/cpython/commit/a285af7e626d1b81cf09f8b2bf7656f100bc1237 3892 # We set a conservative threshold here. 3893 MAX_FUSED_KERNEL_ARGS_NUM = 500 3894 backend_features = dict.fromkeys( 3895 [ 3896 BackendFeature.INPLACE_BUFFERS, 3897 BackendFeature.REDUCE_TO_SINGLE_ELEMENT, 3898 ] 3899 ) 3900 3901 @classmethod 3902 def get_backend_features(cls, device: torch.device): 3903 return cls.backend_features 3904 3905 def __init__(self, scheduler): 3906 super().__init__() 3907 self.scheduler = scheduler 3908 if scheduler: 3909 self.reset_kernel_group() 3910 self._ready_to_flush = False 3911 3912 def _set_flush_status(self, status: bool): 3913 self._ready_to_flush = status 3914 3915 def group_fn(self, sizes): 3916 return tuple(tuple(map(V.graph.sizevars.simplify, s)) for s in sizes) 3917 3918 def reset_kernel_group(self): 3919 from .cpp_wrapper_cpu import CppWrapperCpu 3920 3921 self.kernel_group: Union[CppWrapperKernelGroup, KernelGroup] 3922 if isinstance(V.graph.wrapper_code, CppWrapperCpu): 3923 self.kernel_group = CppWrapperKernelGroup() 3924 else: 3925 self.kernel_group = KernelGroup() 3926 3927 def fuse(self, node1, node2): 3928 if node1.is_foreach() or node2.is_foreach(): 3929 return ForeachKernelSchedulerNode.fuse(node1, node2) 3930 elif node1.is_template(): 3931 assert not node2.is_template() 3932 return FusedSchedulerNode.fuse(node1, node2) 3933 else: 3934 if ( 3935 self._why_fuse_nodes(node1, node2) 3936 == ReasonFusedNodes.COMPATIBLE_RANGES_NO_REDUCTION 3937 ): 3938 assert isinstance(node1, (SchedulerNode, FusedSchedulerNode)) 3939 assert isinstance(node2, (SchedulerNode, FusedSchedulerNode)) 3940 3941 _, (vars1, reduce1) = node1.group 3942 _, (vars2, reduce2) = node2.group 3943 assert reduce1 == () and reduce2 == (), (reduce1, reduce2) 3944 3945 def get_indexing_ranges_exprs(node): 3946 if isinstance(node, FusedSchedulerNode): 3947 assert len(node.snodes) > 0, node.snodes 3948 var_ranges = None 3949 indexing_exprs = set() 3950 for snode in node.snodes: 3951 v, exprs = get_indexing_ranges_exprs(snode) 3952 if var_ranges is None: 3953 var_ranges = v 3954 assert var_ranges == v, (var_ranges, v, node.snodes) 3955 indexing_exprs.update(exprs) 3956 return var_ranges, list(indexing_exprs) 3957 else: 3958 assert isinstance(node, SchedulerNode) 3959 comp_buffer = node.node 3960 assert isinstance(comp_buffer, ir.ComputedBuffer) 3961 _, body, _ = comp_buffer.get_default_sizes_body() 3962 return body.var_ranges, list(body.indexing_exprs.values()) 3963 3964 node_to_recomp = node1 if len(vars1) < len(vars2) else node2 3965 assert isinstance(node_to_recomp, 
SchedulerNode) 3966 3967 ref_node = node2 if len(vars1) < len(vars2) else node1 3968 3969 extra_indexing_constraints = get_indexing_ranges_exprs(ref_node) 3970 3971 node_to_recomp.recompute_size_and_body( 3972 extra_indexing_constraints=extra_indexing_constraints 3973 ) 3974 3975 _, (vars1, _) = node1.group 3976 _, (vars2, _) = node2.group 3977 assert vars1 == vars2, (vars1, vars2) 3978 return FusedSchedulerNode.fuse(node1, node2) 3979 elif self.can_fuse_vertical_outer_loop(node1, node2): 3980 return OuterLoopFusedSchedulerNode.fuse( 3981 node1, node2, self._get_outer_loop_fusion_depth(node1, node2) 3982 ) 3983 else: 3984 return FusedSchedulerNode.fuse(node1, node2) 3985 3986 def _why_fuse_nodes(self, node1, node2) -> Optional[ReasonFusedNodes]: 3987 _, (vars1, reduce1) = node1.group 3988 _, (vars2, reduce2) = node2.group 3989 3990 if vars1 == vars2 and reduce1 == reduce2: 3991 return ReasonFusedNodes.SAME_VARS_REDUCE 3992 if reduce1 == () and vars1 == vars2 + reduce2: 3993 return ReasonFusedNodes.COMPATIBLE_REDUCTION 3994 if self._can_fuse_nodes_with_compatible_ranges(node1, node2): 3995 return ReasonFusedNodes.COMPATIBLE_RANGES_NO_REDUCTION 3996 # TODO(jansel): allow fusion pointwise (vars1, ()) suffix? 3997 return None 3998 3999 def _can_fuse_nodes_with_compatible_ranges(self, node1, node2): 4000 # Here we try to fuse SchedulerNode/FusedSchedulerNode with compatible ranges 4001 # e.g. (s0, s1, s2) and (s0 * s1 * s2) 4002 _, (vars1, reduce1) = node1.group 4003 _, (vars2, reduce2) = node2.group 4004 4005 c1 = reduce1 == () and reduce2 == () 4006 c2 = math.prod(vars1) == math.prod(vars2) 4007 c3 = len(vars1) == 1 or len(vars2) == 1 4008 if not (c1 and c2 and c3): 4009 return False 4010 4011 node_to_recomp = node1 if len(vars1) < len(vars2) else node2 4012 ref_node = node2 if len(vars1) < len(vars2) else node1 4013 4014 # We can not recompute sizes and body for nodes other than SchedulerNode 4015 # TODO: we can extend fusion support with compatible ranges for FusedSchedulerNode 4016 if isinstance(node_to_recomp, FusedSchedulerNode): 4017 return False 4018 4019 # It may happen that node1 and node2 compatible number of elements 4020 # but different original ranges, for example: 4021 # {d0: s0, d1: s1, d2: s2} vs {d0: s0*s1*s2} 4022 # See https://github.com/pytorch/pytorch/pull/120077/files#r1500427848 for more details 4023 # TODO: we can fix if it allows us to CSE at least one of the variables 4024 4025 assert isinstance(node_to_recomp, SchedulerNode) 4026 if isinstance(node_to_recomp.node, ir.TemplateBuffer): 4027 return False 4028 assert isinstance(node_to_recomp.node, ir.ComputedBuffer) 4029 # node.data.get_size() is a cheaper version of node.get_read_writes().var_ranges 4030 # but without variable name 4031 ranges2 = node_to_recomp.node.data.get_size() 4032 ranges1 = None 4033 if isinstance(ref_node, FusedSchedulerNode): 4034 ranges_set = set() 4035 for snode in ref_node.snodes: 4036 if isinstance(snode.node, ir.TemplateBuffer): 4037 break 4038 assert isinstance(snode.node, ir.ComputedBuffer) 4039 ranges_set.add(tuple(snode.node.data.get_size())) 4040 4041 if len(ranges_set) != 1: 4042 return False 4043 4044 ranges1 = list(next(iter(ranges_set))) 4045 else: 4046 assert isinstance(ref_node, SchedulerNode) 4047 assert isinstance(ref_node.node, ir.ComputedBuffer) 4048 ranges1 = ref_node.node.data.get_size() 4049 4050 if ranges1 != ranges2: 4051 return False 4052 4053 return True 4054 4055 def _can_fuse_horizontal_impl(self, node1, node2): 4056 assert isinstance(node1, (FusedSchedulerNode, 
SchedulerNode)) 4057 assert isinstance(node2, (FusedSchedulerNode, SchedulerNode)) 4058 if any( 4059 isinstance(node, OuterLoopFusedSchedulerNode) for node in (node1, node2) 4060 ): 4061 return False 4062 return self._why_fuse_nodes(node1, node2) is not None 4063 4064 def can_fuse_horizontal(self, node1, node2): 4065 if node1.is_template() or node2.is_template(): 4066 return False 4067 if ( 4068 len(node1.get_nodes()) + len(node2.get_nodes()) 4069 > config.cpp.max_horizontal_fusion_size 4070 ): 4071 return False 4072 4073 return self._can_fuse_horizontal_impl(node1, node2) 4074 4075 def _get_outer_loop_fusion_depth(self, node1, node2): 4076 DISABLE_OUTER_LOOP_FUSION = 0 4077 if not all( 4078 type(node) 4079 in (OuterLoopFusedSchedulerNode, FusedSchedulerNode, SchedulerNode) 4080 for node in (node1, node2) 4081 ): 4082 return DISABLE_OUTER_LOOP_FUSION 4083 4084 _node1 = ( 4085 node1.get_outer_nodes()[-1] 4086 if isinstance(node1, OuterLoopFusedSchedulerNode) 4087 else node1 4088 ) 4089 assert isinstance(_node1, (FusedSchedulerNode, SchedulerNode)) 4090 _node2 = ( 4091 node2.get_outer_nodes()[0] 4092 if isinstance(node2, OuterLoopFusedSchedulerNode) 4093 else node2 4094 ) 4095 assert isinstance(_node2, (FusedSchedulerNode, SchedulerNode)) 4096 4097 _, (vars1, reduce1) = _node1.group 4098 _, (vars2, reduce2) = _node2.group 4099 if vars1 == () and vars2 == () and reduce1 != () and reduce2 != (): 4100 # Reduction only 4101 return DISABLE_OUTER_LOOP_FUSION 4102 if all(type(node) is OuterLoopFusedSchedulerNode for node in (node1, node2)): 4103 return ( 4104 node1.outer_loop_fusion_depth 4105 if node1.outer_loop_fusion_depth == node2.outer_loop_fusion_depth 4106 else DISABLE_OUTER_LOOP_FUSION 4107 ) 4108 outer_loop_fusion_depth = min(len(vars1), len(vars2)) 4109 if ( 4110 outer_loop_fusion_depth >= 1 4111 and vars1[:outer_loop_fusion_depth] == vars2[:outer_loop_fusion_depth] 4112 ): 4113 if any( 4114 type(node) is OuterLoopFusedSchedulerNode for node in (node1, node2) 4115 ): 4116 _compare_node = ( 4117 node1 if type(node1) is OuterLoopFusedSchedulerNode else node2 4118 ) 4119 if _compare_node.outer_loop_fusion_depth == outer_loop_fusion_depth: 4120 # Same outer loop fusion depth as prev nodes in OuterLoopFusedSchedulerNode 4121 return outer_loop_fusion_depth 4122 else: 4123 return DISABLE_OUTER_LOOP_FUSION 4124 else: 4125 # First 2 nodes to generate OuterLoopFusedSchedulerNode 4126 return outer_loop_fusion_depth 4127 return DISABLE_OUTER_LOOP_FUSION 4128 4129 def can_fuse_vertical_outer_loop(self, node1, node2): 4130 return ( 4131 not node1.is_template() 4132 and not node2.is_template() 4133 and node1.get_operation_names() & node2.ancestors 4134 and not ( 4135 self._can_fuse_horizontal_impl(node1, node2) 4136 and not node1.is_reduction() 4137 ) 4138 and self._get_outer_loop_fusion_depth(node1, node2) >= 1 4139 ) 4140 4141 def get_fusion_pair_priority(self, node1, node2): 4142 if self.can_fuse_vertical_outer_loop(node1, node2): 4143 # Outer loop fusion with lower priority 4144 return 1 4145 else: 4146 return 0 4147 4148 def can_fuse_vertical(self, node1, node2): 4149 if node2.is_template(): 4150 # TODO(jgong5): support pre-op fusion with template 4151 return False 4152 if node1.is_template(): 4153 return not node2.is_reduction() 4154 return ( 4155 self._can_fuse_horizontal_impl(node1, node2) and not node1.is_reduction() 4156 ) or self.can_fuse_vertical_outer_loop(node1, node2) 4157 4158 def try_loop_split(self, nodes: List[SchedulerNode]): 4159 """ 4160 Apply loop split optimization. 
        When one of the indexing_exprs contains a division, we eliminate the division by
        splitting the loop to avoid non-contiguous loads, subject to the following conditions:
            1. No reduction and no modular indexing for any node.
            2. Only one indexing expression of one node contains a division. That expression
               determines the dimension to split, and the split dimension must be contiguous
               in all other indexing expressions.

        For example, if a node has var_ranges: {z0: 2, z1: 9216, z2: 960} and indexing_exprs:
        {'index0': 8847360*z0 + 960*z1 + z2, 'index1': 32*z0 + (z2//30), 'index2': z2},
        we split z2 -> 30*z2 + z3. The node's var_ranges then become
        {z0: 2, z1: 9216, z2: 32, z3: 30} and its indexing_exprs become
        {'index0': 8847360*z0 + 960*z1 + 30*z2 + z3, 'index1': 32*z0 + z2, 'index2': 30*z2 + z3}.
        """

        # No reduction and no modular indexing
        if any(
            len(node.group[1][1]) != 0
            or any(
                expr.has(ModularIndexing) for expr in node._body.indexing_exprs.values()
            )
            for node in nodes
        ):
            return nodes

        split_var = None
        split_number = None
        divide_index_name = None
        num_div = 0
        match_div = False
        matched_node = None

        for node in nodes:
            assert isinstance(node.node, ir.ComputedBuffer)
            _, original_body, _ = node.node.get_default_sizes_body()
            for name, expr in original_body.indexing_exprs.items():
                num_div += expr.count(FloorDiv)
                if num_div > 1:
                    return nodes
                if expr.count(FloorDiv) == 1:
                    div_expr = expr.find(FloorDiv).pop()
                    split_var = div_expr.args[0]
                    split_number = div_expr.args[1]
                    divide_index_name = name
            if (
                isinstance(split_number, sympy.core.numbers.Integer)
                and isinstance(split_var, sympy.core.symbol.Symbol)
                and split_var in original_body.iter_vars
                and divide_index_name is not None
                and all(
                    stride_at_vec_range(expr, split_var) == 1
                    for name, expr in original_body.indexing_exprs.items()
                    if name != divide_index_name
                )
            ):
                match_div = True
                matched_node = node

        # Only one node contains a division, and the split dimension is contiguous in all other indexing_exprs.
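        # Illustrative sketch of the sympy queries used above (expression made up for
        # the example): for an index expression like 32*z0 + FloorDiv(z2, 30),
        #     expr.count(FloorDiv)        -> 1
        #     expr.find(FloorDiv).pop()   -> FloorDiv(z2, 30)
        #     div_expr.args               -> (z2, 30), i.e. (split_var, split_number)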
4219 if not match_div: 4220 return nodes 4221 4222 extra_indexing_constraints = None 4223 4224 def loop_split(sizes, body, vars): 4225 index_size, reduce_size = sizes 4226 index_vars, reduce_vars = vars 4227 split_idx = index_vars.index(split_var) 4228 new_index_size = index_size.copy() 4229 new_index_size[split_idx] = index_size[split_idx] // split_number 4230 new_index_size.insert(split_idx + 1, split_number) 4231 (new_index_vars, _), var_ranges = dependencies.index_vars_no_squeeze( 4232 new_index_size, reduce_size, prefix="y" 4233 ) 4234 iter_vars = new_index_vars.copy() 4235 divisor_var = iter_vars.pop(split_idx + 1) 4236 iter_vars[split_idx] = split_number * iter_vars[split_idx] + divisor_var 4237 body = ir.LoopBody( 4238 body, [iter_vars, reduce_vars], var_ranges, new_index_vars, reduce_vars 4239 ) 4240 nonlocal extra_indexing_constraints 4241 if not extra_indexing_constraints: 4242 extra_indexing_constraints = ( 4243 body.var_ranges, 4244 list(body.indexing_exprs.values()), 4245 ) 4246 return ( 4247 (new_index_size, reduce_size), 4248 body, 4249 (new_index_vars, reduce_vars), 4250 ) 4251 4252 # Here decide the final loop order 4253 for node in nodes: 4254 if node == matched_node: 4255 node.recompute_size_and_body(recompute_sizes_body_func=loop_split) 4256 for node in nodes: 4257 if node != matched_node: 4258 node.recompute_size_and_body( 4259 extra_indexing_constraints=extra_indexing_constraints, 4260 recompute_sizes_body_func=loop_split, 4261 ) 4262 4263 return nodes 4264 4265 def codegen_outer_loop_node( 4266 self, 4267 node: OuterLoopFusedSchedulerNode, 4268 ): 4269 """ 4270 Generate the code for the outer loop fused scheduler node. 4271 1. Codegen with fused outer loop: depends on the analysis of 4272 the outer loop fused scheduler node, with or without the local buffer. 4273 2. If failed, fallback to standard codegen. 4274 """ 4275 kernel_group = self.kernel_group 4276 generated_cpp_vec_kernel_count = metrics.generated_cpp_vec_kernel_count 4277 cpp_kernel_proxy_list: List[CppKernelProxy] = [] 4278 nodes_list: List[List[SchedulerNode]] = [] 4279 assert isinstance(node, OuterLoopFusedSchedulerNode) 4280 4281 def try_outer_loop_fusion_with_local_buf(node: OuterLoopFusedSchedulerNode): 4282 """ 4283 Codegen code with fused outer loop and local Buffer. 4284 """ 4285 assert isinstance(node, OuterLoopFusedSchedulerNode) 4286 cpp_kernel_proxy_list.clear() 4287 nodes_list.clear() 4288 4289 def get_call_ranges(node: BaseSchedulerNode): 4290 assert isinstance(node, (SchedulerNode, FusedSchedulerNode)) 4291 nodes: List[SchedulerNode] = node.get_nodes() # type: ignore[assignment] 4292 _, (group, reduction_group) = max( 4293 nodes, key=lambda x: int(x.is_reduction()) 4294 ).group 4295 call_ranges = tuple(group) + tuple(reduction_group) 4296 return call_ranges 4297 4298 local_buffers: List[ir.Buffer] = [] 4299 # Map local buffer name to a list of global buffers 4300 local_to_global_buffers: Dict[str, List[ir.Buffer]] = {} 4301 if all( 4302 len(get_call_ranges(_node)) == node.outer_loop_fusion_depth + 1 4303 for _node in node.get_outer_nodes() 4304 ): 4305 # Ref to the typical case of local buffer 4306 # in https://github.com/pytorch/pytorch/blob/ 4307 # 1115a25c36340554442f28f9570abd42f0aface2/aten/src/ATen/native/cpu/SoftMaxKernel.cpp#L159 4308 # where the buffer is with size of last dim and contiguous. 4309 # Only support this typical case at first. 
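                # Illustrative sketch of the targeted pattern (shapes made up for the
                # example): in a softmax-like chain over the last dim of an (M, N) input,
                # the intermediate exp buffer of shape (M, N) is written and read only
                # within a single iteration of the fused outer loop over M, so it can be
                # replaced by one reusable (N,)-sized local buffer inside that loop.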
4310 visited_scheduler_nodes: Set[str] = set() 4311 for scheduler_node in node.get_nodes(): 4312 # all users inside same OuterLoopFusedSchedulerNode 4313 assert isinstance(scheduler_node, SchedulerNode) 4314 visited_scheduler_nodes.add(scheduler_node.get_name()) 4315 if ( 4316 scheduler_node.is_reduction() 4317 or len(scheduler_node.get_outputs()) != 1 4318 ): 4319 continue 4320 4321 scheduler_buffer = scheduler_node.get_outputs()[0] 4322 if all( 4323 user.node in node.get_nodes() for user in scheduler_buffer.users 4324 ): 4325 global_buffer = scheduler_buffer.node 4326 assert isinstance(global_buffer, ir.ComputedBuffer) 4327 global_buffer_layout = global_buffer.get_layout() 4328 size_offset = node.outer_loop_fusion_depth - len( 4329 get_call_ranges(scheduler_node) 4330 ) 4331 4332 def is_all_write_read_contiguous(): 4333 contiguous_index_expr = 0 4334 stride = 1 4335 for var, range in reversed( 4336 scheduler_node._body.var_ranges.items() 4337 ): 4338 contiguous_index_expr += stride * var 4339 stride *= range 4340 write_index_expr = scheduler_node._body.get_write_expr( 4341 scheduler_buffer.get_name() 4342 ) 4343 4344 def is_contiguous_index(x): 4345 return x == contiguous_index_expr 4346 4347 return is_contiguous_index(write_index_expr) and all( 4348 isinstance(user.node, SchedulerNode) 4349 and is_contiguous_index( 4350 user.node._body.get_read_expr( 4351 scheduler_buffer.get_name() 4352 ), 4353 ) 4354 for user in scheduler_buffer.users 4355 ) 4356 4357 if not ( 4358 global_buffer_layout.is_contiguous() 4359 and is_all_write_read_contiguous() 4360 ): 4361 continue 4362 # Local Buffer is a view of global buffer 4363 local_buffer_layout = ir.FixedLayout( 4364 global_buffer_layout.device, 4365 global_buffer_layout.dtype, 4366 global_buffer_layout.size[size_offset:], 4367 global_buffer_layout.stride[size_offset:], 4368 ) 4369 4370 def try_share_local_buffer(local_buffer_layout, local_buffers): 4371 for local_buf in local_buffers: 4372 if local_buffer_layout == local_buf.layout and all( 4373 all( 4374 user.node.get_name() in visited_scheduler_nodes 4375 for user in V.graph.scheduler.name_to_buf[ 4376 global_buffer.name 4377 ].users 4378 ) 4379 for global_buffer in local_to_global_buffers[ 4380 local_buf.name 4381 ] 4382 if global_buffer.name is not None 4383 ): 4384 return local_buf 4385 return None 4386 4387 local_buf_prefix = "local_buffer_data" 4388 # Share existing local buffer 4389 local_buffer_used = try_share_local_buffer( 4390 local_buffer_layout, local_buffers 4391 ) 4392 if not local_buffer_used: 4393 # Create new local buffer 4394 local_buffer_used = ir.Buffer( 4395 f"{local_buf_prefix}_{len(local_buffers)}", 4396 local_buffer_layout, 4397 ) 4398 local_buffers.append(local_buffer_used) 4399 local_to_global_buffers[local_buffer_used.name] = [] 4400 local_to_global_buffers[local_buffer_used.name].append( 4401 global_buffer, 4402 ) 4403 4404 with LocalBufferContext(kernel_group.args) as scope: 4405 if len(local_buffers) > 0: 4406 for local_buffer in local_buffers: 4407 assert local_buffer.name is not None 4408 scope.add_local_buffer( 4409 local_buffer, local_to_global_buffers[local_buffer.name] 4410 ) 4411 for _node in node.get_outer_nodes(): 4412 assert isinstance(_node, (FusedSchedulerNode, SchedulerNode)) 4413 cpp_kernel_proxy = CppKernelProxy(kernel_group) 4414 cpp_kernel_proxy.codegen_nodes(_node.get_nodes()) # type: ignore[arg-type] 4415 cpp_kernel_proxy_list.append(cpp_kernel_proxy) 4416 nodes_list.append(_node.get_nodes()) # type: ignore[arg-type] 4417 4418 if not 
node.check_outer_fusion_loop_level_attr( 4419 cpp_kernel_proxy_list, node.outer_loop_fusion_depth 4420 ): 4421 return False 4422 metrics.cpp_outer_loop_fused_inner_counts.append( 4423 metrics.CppOuterLoopFusedCount( 4424 len(cpp_kernel_proxy_list), 4425 local_buffer_number=len(scope.local_buffers), 4426 ) 4427 ) 4428 outer_fusion_cpp_kernel_proxy = node.merge_outer_fusion_kernels( 4429 cpp_kernel_proxy_list, 4430 ) 4431 kernel_group.finalize_kernel( 4432 outer_fusion_cpp_kernel_proxy, 4433 [_node for _nodes in nodes_list for _node in _nodes], 4434 ) 4435 4436 return True 4437 4438 if not try_outer_loop_fusion_with_local_buf(node): 4439 # Reset generated_cpp_vec_kernel_count to codegen again 4440 metrics.generated_cpp_vec_kernel_count = generated_cpp_vec_kernel_count 4441 cpp_kernel_proxy_list.clear() 4442 nodes_list.clear() 4443 # Similar as comment in 4444 # https://github.com/pytorch/pytorch/blob/469383755fe416eb1c41fa724762ad3eaecdff07/torch/_inductor/codegen/cpp.py#L3269-L3272 4445 # Kernels share the same global contexts like V.graph.wrapper_code, V.kernel.args. 4446 with torch._inductor.config.patch(inplace_buffers=False): 4447 for _node in node.get_outer_nodes(): 4448 assert isinstance(_node, (FusedSchedulerNode, SchedulerNode)) 4449 _nodes: List[SchedulerNode] = _node.get_nodes() # type: ignore[assignment] 4450 cpp_kernel_proxy = CppKernelProxy(kernel_group) 4451 cpp_kernel_proxy.codegen_nodes(_nodes) 4452 kernel_group.finalize_kernel(cpp_kernel_proxy, _nodes) 4453 4454 def codegen_node( 4455 self, 4456 node: Union[OuterLoopFusedSchedulerNode, FusedSchedulerNode, SchedulerNode], 4457 ): 4458 """ 4459 Turn an set of pre-fused nodes into a C++ kernel. 4460 """ 4461 kernel_group = self.kernel_group 4462 4463 if isinstance(node, OuterLoopFusedSchedulerNode): 4464 self.codegen_outer_loop_node(node) 4465 else: 4466 nodes: List[SchedulerNode] = node.get_nodes() # type: ignore[assignment] 4467 nodes = self.try_loop_split(nodes) 4468 cpp_kernel_proxy = CppKernelProxy(kernel_group) 4469 cpp_kernel_proxy.codegen_nodes(nodes) 4470 kernel_group.finalize_kernel(cpp_kernel_proxy, nodes) 4471 4472 args_num = self._get_scheduled_num_args() 4473 if args_num > CppScheduling.MAX_FUSED_KERNEL_ARGS_NUM: 4474 self._set_flush_status(True) 4475 4476 def is_cpp_template(self, node: BaseSchedulerNode) -> bool: 4477 return isinstance(node, SchedulerNode) and isinstance( 4478 node.node, ir.CppTemplateBuffer 4479 ) 4480 4481 def codegen_template( 4482 self, 4483 template_node: BaseSchedulerNode, 4484 epilogue_nodes: Sequence[BaseSchedulerNode], 4485 ): 4486 """ 4487 Codegen a CPP template, possibly with fused epilogues 4488 """ 4489 counters["inductor"]["cpp_epilogue_fusion_counter"] += len(epilogue_nodes) 4490 assert self.is_cpp_template( 4491 template_node 4492 ), "Template node passed to CppScheduler.codegen_template must be a SchedulerNode that wraps a CppTemplateBuffer" 4493 template_node = cast(SchedulerNode, template_node) 4494 _, (_, rnumel) = template_node.group 4495 assert rnumel == () 4496 ctb: ir.CppTemplateBuffer = cast(ir.CppTemplateBuffer, template_node.node) 4497 epilogue_ir_nodes: List[Optional[ir.Operation]] = [ 4498 n.node for n in epilogue_nodes 4499 ] 4500 assert all( 4501 isinstance(n, ir.ComputedBuffer) for n in epilogue_ir_nodes 4502 ), "Epilogue nodes must all be instances of ir.ComputedBuffer" 4503 4504 def template_buffer_has_other_users( 4505 template_buffer, outputs_by_name, epilogue_nodes 4506 ): 4507 assert template_buffer.get_name() in outputs_by_name 4508 users = 
outputs_by_name[template_buffer.get_name()].users 4509 return not all( 4510 isinstance(user.node, BaseSchedulerNode) 4511 and user.node.node in epilogue_nodes 4512 for user in users 4513 ) 4514 4515 flag_template_buffer_has_other_users = template_buffer_has_other_users( 4516 ctb, template_node.outputs_by_name, epilogue_ir_nodes 4517 ) 4518 kernel, render = ctb.make_kernel_render( 4519 ctb, 4520 flag_template_buffer_has_other_users=flag_template_buffer_has_other_users, 4521 epilogue_nodes=epilogue_ir_nodes, 4522 ) 4523 with kernel: 4524 for node in [template_node, *epilogue_nodes]: 4525 node.mark_run() # type: ignore[attr-defined] 4526 src_code = render() 4527 4528 with V.set_kernel_handler(kernel): 4529 node_schedule = [template_node, *epilogue_nodes] 4530 kernel_name = self.define_kernel(src_code, node_schedule, kernel.args) 4531 kernel.call_kernel(kernel_name, ctb) 4532 V.graph.removed_buffers |= kernel.removed_buffers 4533 self.scheduler.free_buffers() 4534 4535 def _get_scheduled_num_args(self): 4536 return self.kernel_group.get_num_args() 4537 4538 def ready_to_flush(self): 4539 return self._ready_to_flush 4540 4541 def codegen_sync(self): 4542 pass 4543 4544 def define_kernel(self, src_code, nodes, kernel_args=None): 4545 wrapper = V.graph.wrapper_code 4546 fused_name = ( 4547 get_fused_kernel_name(nodes, config.cpp.descriptive_names) 4548 if config.cpp.descriptive_names 4549 else "" 4550 ) 4551 kernel_name = "_".join(["cpp", fused_name, wrapper.next_kernel_suffix()]) 4552 kernel_decl_name = kernel_name if V.graph.cpp_wrapper else "kernel" 4553 src_code = src_code.replace(str(Placeholder.KERNEL_NAME), kernel_decl_name) 4554 src_code = src_code.replace(str(Placeholder.DESCRIPTIVE_NAME), kernel_name) 4555 # TODO(voz): Ostensibly, we should not need this. But there are cases where C++ codegen does 4556 # not use BracesBuffer, so we have no good indicator of a C++ buffer atm. 
4557 src_code = src_code.replace("#pragma CMT", "//") 4558 4559 compile_wrapper = IndentedBuffer() 4560 args = self.kernel_group.args if kernel_args is None else kernel_args 4561 _, _, arg_types = args.cpp_argdefs() 4562 if not V.graph.cpp_wrapper: 4563 compile_wrapper.writeline(f"async_compile.cpp_pybinding({arg_types!r}, '''") 4564 compile_wrapper.splice(src_code, strip=True) 4565 if not V.graph.cpp_wrapper: 4566 compile_wrapper.writeline("''')") 4567 wrapper.define_kernel(kernel_name, compile_wrapper.getvalue(), cuda=False) 4568 return kernel_name 4569 4570 def flush(self): 4571 src_code = self.kernel_group.codegen_group() 4572 if src_code: 4573 kernel_name = self.define_kernel( 4574 src_code, self.kernel_group.scheduled_nodes 4575 ) 4576 self.kernel_group.call_kernel(V.graph.wrapper_code, kernel_name) 4577 self.reset_kernel_group() 4578 self._set_flush_status(False) 4579 4580 4581class KernelGroup: 4582 def __init__(self): 4583 super().__init__() 4584 self.args = KernelArgs() 4585 self.loops_code = BracesBuffer() 4586 self.ws = WorkSharing(self.loops_code) 4587 self.stack = contextlib.ExitStack() 4588 self.stack.enter_context(self.ws) 4589 self.scheduled_nodes = [] 4590 4591 def new_kernel(self, cls, *args): 4592 return cls(self.args, parallel_num_threads(), *args) 4593 4594 def finalize_kernel(self, new_kernel, nodes): 4595 self.scheduled_nodes += nodes 4596 code = self.loops_code 4597 ws = self.ws 4598 new_kernel.codegen_loops(code, ws) 4599 4600 def get_num_args(self): 4601 arg_defs, call_args, arg_types = self.args.cpp_argdefs() 4602 args_num = len(arg_defs) 4603 return args_num 4604 4605 def codegen_group(self, name=None) -> str: 4606 self.stack.close() 4607 if not self.scheduled_nodes: 4608 return "" 4609 code = BracesBuffer() 4610 # 1. Include header files 4611 # TODO: support kernel profile on other platforms 4612 enable_kernel_profile = config.cpp.enable_kernel_profile and sys.platform in [ 4613 "linux", 4614 "win32", 4615 ] 4616 if enable_kernel_profile: 4617 code.writelines(["#include <ATen/record_function.h>"]) 4618 code.writeline(codecache.cpp_prefix()) 4619 4620 # 2. Function definition 4621 kernel_decl_name = str(Placeholder.KERNEL_NAME) if name is None else name 4622 kernel_name = str(Placeholder.DESCRIPTIVE_NAME) if name is None else name 4623 arg_defs, _, _ = self.args.cpp_argdefs() 4624 arg_defs = ",\n".ljust(25).join(arg_defs) 4625 func_export_decl = get_export_declaration() 4626 code.writeline( 4627 f'extern "C" {func_export_decl} void {kernel_decl_name}({arg_defs})' 4628 ) 4629 4630 # 3. 
Function body 4631 with code.indent(): 4632 if enable_kernel_profile: 4633 graph_id = V.graph.graph_id 4634 prefix = "graph_" + str(graph_id) + "_" if graph_id is not None else "" 4635 code.writelines( 4636 [ 4637 f'RECORD_FUNCTION("{prefix + kernel_name}", c10::ArrayRef<c10::IValue>({{}}));' 4638 ] 4639 ) 4640 for old, new in self.args.aliases(): 4641 code.writeline(f"auto {old} = {new};") 4642 code.splice(self.loops_code) 4643 return code.getvalue() 4644 4645 def call_kernel(self, wrapper, kernel_name): 4646 _, call_args, arg_types = self.args.cpp_argdefs() 4647 wrapper.generate_kernel_call( 4648 kernel_name, call_args, cuda=False, arg_types=arg_types 4649 ) 4650 4651 4652class CppWrapperKernelGroup(KernelGroup): 4653 def __init__(self): 4654 super().__init__() 4655 self.args = CppWrapperKernelArgs() 4656 4657 4658class WorkSharing: 4659 def __init__(self, code): 4660 self.code = code 4661 self.in_parallel = False 4662 self.num_threads = None 4663 self.stack = contextlib.ExitStack() 4664 4665 def parallel(self, threads): 4666 if self.in_parallel and threads != self.num_threads: 4667 # wrong number of threads 4668 self.close() 4669 if not self.in_parallel: 4670 self.num_threads = threads 4671 self.in_parallel = True 4672 if config.cpp.dynamic_threads: 4673 self.code.writeline("#pragma omp parallel") 4674 else: 4675 self.code.writeline(f"#pragma omp parallel num_threads({threads})") 4676 self.stack.enter_context(self.code.indent()) 4677 self.code.writeline( 4678 "int tid = omp_get_thread_num();", 4679 ) 4680 4681 def single(self): 4682 if self.in_parallel: 4683 self.code.writeline("#pragma omp single") 4684 return self.in_parallel 4685 4686 def close(self): 4687 self.stack.close() 4688 self.in_parallel = False 4689 4690 def __enter__(self): 4691 self.stack.__enter__() 4692 return self 4693 4694 def __exit__(self, exc_type, exc_val, exc_tb): 4695 self.stack.__exit__(exc_type, exc_val, exc_tb) 4696 4697 4698@dataclasses.dataclass 4699class LoopLevel: 4700 var: Optional[sympy.Expr] = None 4701 size: Optional[sympy.Expr] = None 4702 offset: sympy.Expr = sympy.Integer(0) 4703 steps: sympy.Expr = sympy.Integer(1) 4704 parallel: int = 0 4705 simd_omp: bool = False 4706 simd_vec: bool = False 4707 collapsed: bool = False 4708 is_reduction: bool = False 4709 parent: Optional["LoopLevel"] = None 4710 # the next inner level of the loop, empty if it is inner-most 4711 # contains >1 LoopLevel if the inner level of loop is split 4712 inner: List["LoopLevel"] = dataclasses.field(default_factory=list) 4713 # kernel assigned to this loop level, only valid when it is a leaf 4714 kernel: Optional[CppKernel] = None 4715 4716 def __post_init__(self): 4717 # Regarding the C++/OpenMP backend, `cpu_vec_isa.pick_vec_isa()` to check 4718 # vectorization ISA is a time-consuming and one-shot operation. It leads 4719 # to taking a longer time to import `codegen.cpp` package because the 4720 # `LoopLevel` of the package is decorated by `@dataclasses.dataclass` while 4721 # the decorator will invoke `cpu_vec_isa.pick_vec_isa()` to initialize the 4722 # `simd_nelements` of the `LoopLevel`. It might introduce additional compilation 4723 # overhead to the Triton backend. 
Therefore, we moved the `simd_nelements` to 4724 # `__post_init__` 4725 picked_vec_isa: cpu_vec_isa.VecISA = cpu_vec_isa.pick_vec_isa() 4726 self.simd_nelements: int = picked_vec_isa.nelements() if picked_vec_isa else 0 4727 4728 def get_kernels(self) -> List[CppKernel]: 4729 """Get all kernel objects under this loop level""" 4730 if self.kernel: 4731 return [self.kernel] 4732 kernels = [] 4733 for loop in self.inner: 4734 kernels += loop.get_kernels() 4735 return kernels 4736 4737 def get_root(self): 4738 """Get all kernel objects under this loop level""" 4739 root = self 4740 while root.parent: 4741 root = root.parent 4742 return root 4743 4744 def set_kernel(self, kernel: CppKernel): 4745 """ 4746 Set the kernel under this loop level. No split is allowed under 4747 this loop level. 4748 """ 4749 if not self.inner: 4750 self.kernel = kernel 4751 loop: Optional[LoopLevel] = self 4752 assert loop is not None 4753 return 4754 assert len(self.inner) == 1 4755 self.inner[0].set_kernel(kernel) 4756 4757 def get_loops_at(self, depth) -> List["LoopLevel"]: 4758 if depth == 0: 4759 return [self] 4760 else: 4761 loops = [] 4762 for loop in self.inner: 4763 loops += loop.get_loops_at(depth - 1) 4764 return loops 4765 4766 def split_with_tiling(self, depth, factor): 4767 def clone_inner(): 4768 inner = [] 4769 if self.inner: 4770 for loop in self.inner: 4771 inner.append(loop.clone()) 4772 return inner 4773 4774 def do_split_with_tiling(): 4775 sympy_factor = sympy.Integer(factor) 4776 4777 offset = FloorDiv(self.size, sympy_factor) * sympy_factor 4778 main_loop = LoopLevel(self.var, offset) 4779 main_loop.steps = sympy_factor 4780 main_loop.parallel = self.parallel 4781 main_loop.collapsed = False 4782 main_loop.is_reduction = self.is_reduction 4783 main_loop.inner = clone_inner() 4784 if main_loop.inner: 4785 for loop in main_loop.inner: 4786 loop.parent = main_loop 4787 4788 tail_loop = LoopLevel(self.var, self.size) 4789 tail_loop.offset = offset 4790 tail_loop.parallel = self.parallel 4791 tail_loop.collapsed = False 4792 tail_loop.is_reduction = self.is_reduction 4793 tail_loop.inner = clone_inner() 4794 if tail_loop.inner: 4795 for loop in tail_loop.inner: 4796 loop.parent = tail_loop 4797 4798 return main_loop, tail_loop 4799 4800 if depth == 0: 4801 main_loop, tail_loop = do_split_with_tiling() 4802 parent = self.parent 4803 if parent: 4804 parent.inner = [main_loop, tail_loop] 4805 main_loop.parent = parent 4806 tail_loop.parent = parent 4807 return main_loop, tail_loop 4808 else: 4809 assert len(self.inner) == 1 4810 return self.inner[0].split_with_tiling(depth - 1, factor) 4811 4812 def clone(self): 4813 loop = copy(self) 4814 loop.inner = [] 4815 if self.inner: 4816 for inner_loop in self.inner: 4817 inner_loop_clone = inner_loop.clone() 4818 inner_loop_clone.parent = loop 4819 loop.inner.append(inner_loop_clone) 4820 loop.kernel = deepcopy(self.kernel) 4821 return loop 4822 4823 def lines(self): 4824 offset_expr = cexpr_index(self.offset) 4825 size_expr = cexpr_index(self.size) 4826 if config.cpp.no_redundant_loops and offset_expr == size_expr: 4827 return None 4828 simd = ( 4829 f"simd simdlen({self.simd_nelements}) " 4830 if self.simd_omp and self.simd_nelements > 1 4831 else "" 4832 ) 4833 if self.parallel: 4834 # TODO(jansel): look into chunk size and other schedules 4835 line1 = "#pragma omp for" 4836 if self.parallel > 1: 4837 line1 += f" collapse({self.parallel})" 4838 if self.simd_omp: 4839 line1 = line1.replace(" for ", f" for {simd}") 4840 elif self.simd_vec: 4841 line1 = "" 
4842 elif self.simd_omp: 4843 line1 = f"#pragma omp {simd}" 4844 elif not self.is_reduction and cpp_builder.is_gcc(): 4845 line1 = "#pragma GCC ivdep" 4846 else: 4847 line1 = "" 4848 offset_str = f"{INDEX_TYPE} {self.var}={offset_expr}" 4849 size_str = f"{self.var}<{size_expr}" 4850 if self.steps.is_number: 4851 steps_str = f"{self.var}+={cexpr_index(self.steps)}" 4852 else: 4853 # If the step size is 0, change it to 1 because a step size of 0 4854 # will cause floating point exception (core dump) during parallelization. 4855 steps_str = ( 4856 f"{self.var}+=({cexpr_index(self.steps)} == 0 ? " 4857 f"1 : {cexpr_index(self.steps)})" 4858 ) 4859 line2 = f"for({offset_str}; {size_str}; {steps_str})" 4860 if self.collapsed or not line1: 4861 return [line2] 4862 return [line1, line2] 4863 4864 4865@dataclasses.dataclass 4866class LoopNestWithSplit: 4867 """ 4868 A loop-nest like structure but with some loop level split along 4869 the loop range into the main tiling loop and the tail. It is built 4870 with the `build` method as a loop nest and then split with 4871 `split_with_tiling` at some depth. 4872 4873 A typical case is for vectorization where we typically split at the inner-most 4874 loop level. A more complicated case is 2D tiling where we split at 4875 both inner-most and outer levels. 4876 """ 4877 4878 root: Optional[List[LoopLevel]] = None 4879 kernel: Optional[CppKernel] = None 4880 4881 @staticmethod 4882 def build(kernel: CppKernel): 4883 """Build a LoopNest with the given `kernel` as the leaf""" 4884 itervars = kernel.itervars 4885 ranges = kernel.ranges 4886 reduction_depth = kernel.reduction_depth 4887 assert reduction_depth is not None 4888 4889 root: List[LoopLevel] = [] 4890 levels: List[LoopLevel] = root 4891 loop: Optional[LoopLevel] = None 4892 for loop_idx, (var, size) in enumerate(zip(itervars, ranges)): 4893 loop = LoopLevel(var, size, parent=loop) 4894 if loop_idx >= reduction_depth: 4895 loop.is_reduction = kernel.is_reduction 4896 levels.append(loop) 4897 levels = loop.inner 4898 loop_nest = LoopNestWithSplit(root) 4899 if loop: 4900 loop.kernel = kernel 4901 else: 4902 loop_nest.kernel = kernel 4903 return loop_nest 4904 4905 def __bool__(self): 4906 return bool(self.root) 4907 4908 def get_loops_at(self, depth) -> List[LoopLevel]: 4909 """Get all the loop levels at the given `depth` (most outer loop has depth 0)""" 4910 loops: List[LoopLevel] = [] 4911 assert self.root is not None 4912 for loop in self.root: 4913 loops += loop.get_loops_at(depth) 4914 return loops 4915 4916 @cache_on_self 4917 def max_parallel_depth(self): 4918 """ 4919 Maximal allowed depth for parallelism: 4920 1) Levels without splitting and 4921 2) All reduction or non-reduction levels 4922 When the loop is split at the top level, the max depth is 1. 4923 """ 4924 max_depth = 0 4925 assert self.root is not None 4926 loops = self.root 4927 if len(loops) > 1: 4928 return 1 4929 is_reduction = loops[0].is_reduction if loops else False 4930 while len(loops) == 1 and loops[0].is_reduction == is_reduction: 4931 max_depth += 1 4932 loops = loops[0].inner 4933 return max_depth 4934 4935 def is_reduction_only(self): 4936 """ 4937 Whether all the loops are for reduction. Reduction loops 4938 are always the inner most ones. 
4939 """ 4940 return ( 4941 self.root is not None and len(self.root) > 0 and self.root[0].is_reduction 4942 ) 4943 4944 def mark_parallel(self, par_depth): 4945 assert ( 4946 par_depth <= self.max_parallel_depth() 4947 ), "Parallel depth cannot exceed the maximal allowed parallel depth" 4948 assert self.root is not None 4949 loops = self.root 4950 for loop in loops: 4951 loop.parallel = par_depth 4952 for i in range(1, par_depth): 4953 loops = loops[0].inner 4954 loops[0].collapsed = True 4955 4956 def split_with_tiling(self, depth, factor): 4957 """ 4958 Split the loop into main and tail loops at given `depth` so that the range 4959 of the main loop has range `floor_div(range, factor) * factor` and 4960 the tail loop handles the remainder. The main loop is tiled 4961 according to the `factor`. 4962 """ 4963 loops = self.get_loops_at(depth) 4964 assert len(loops) == 1 4965 split_loops = loops[0].split_with_tiling(0, factor) 4966 if depth == 0: 4967 self.root = split_loops 4968 return split_loops 4969 4970 def get_kernels(self) -> List[CppKernel]: 4971 """Get all kernel objects under this loop nest""" 4972 if self.kernel: 4973 return [self.kernel] 4974 kernels: List[CppKernel] = [] 4975 assert self.root is not None 4976 for loop in self.root: 4977 kernels += loop.get_kernels() 4978 return kernels 4979