1# mypy: allow-untyped-defs 2import collections 3import contextlib 4import dataclasses 5import functools 6import itertools 7import logging 8import re 9import textwrap 10import traceback 11from contextlib import nullcontext 12from functools import partial 13from typing import ( 14 Any, 15 Callable, 16 ClassVar, 17 Dict, 18 Iterable, 19 List, 20 Optional, 21 Sequence, 22 Set, 23 Tuple, 24 TYPE_CHECKING, 25 Union, 26) 27from unittest.mock import patch 28 29import sympy 30from sympy import Expr, Integer 31 32import torch._export.serde.schema as export_schema 33 34import torch._logging 35 36import torch.fx 37import torch.utils._pytree as pytree 38from torch._dynamo.device_interface import get_interface_for_device 39from torch._dynamo.utils import identity 40from torch._export.serde.serialize import GraphModuleSerializer 41from torch._higher_order_ops.auto_functionalize import can_auto_functionalize 42from torch._inductor import metrics 43from torch._prims_common import ( 44 compute_required_storage_length, 45 is_boolean_dtype, 46 is_float_dtype, 47 make_channels_last_strides_for, 48 StrideType, 49) 50from torch._subclasses.fake_tensor import get_schema_info 51from torch.fx.experimental.symbolic_shapes import ( 52 CallMethodKey, 53 compute_unbacked_bindings, 54 DivideByKey, 55 free_unbacked_symbols, 56 rebind_unbacked, 57 resolve_unbacked_bindings, 58 SymTypes, 59) 60from torch.utils._sympy.functions import CleanDiv, FloorDiv, ModularIndexing 61from torch.utils._sympy.symbol import SymT 62 63from . import config, dependencies 64from .codegen.common import index_prevent_reordering 65from .dependencies import ( 66 extract_free_unbacked_symbols, 67 extract_input_node_reduction_ranges, 68 extract_read_writes, 69 var_builder, 70) 71from .ops_handler import OpCounterCSE 72from .runtime.hints import ReductionHint 73from .runtime.runtime_utils import do_bench 74from .utils import ( 75 argsort, 76 cache_on_self, 77 ceildiv, 78 convert_shape_to_inductor, 79 convert_shape_to_symint, 80 developer_warning, 81 get_kernel_metadata, 82 is_dynamic, 83 is_gpu, 84 pad_listlike, 85 sympy_dot, 86 sympy_index_symbol, 87 sympy_index_symbol_with_prefix, 88 sympy_product, 89 sympy_subs, 90) 91from .virtualized import ops, V 92 93if TYPE_CHECKING: 94 from .graph import GraphLowering 95 96log = logging.getLogger(__name__) 97indent = functools.partial(textwrap.indent, prefix=" ") 98aten = torch.ops.aten 99 100""" [Note: Inductor IR] 101 102Inductor's IR is produced by executing 'lowering' code (see lowering.py). Each 103lowering is registered to a particular aten operator, and expects inputs that 104correspond to the aten schema. However, in place of torch Tensor inputs, lowerings 105expect Inductor TensorBox inputs. 106 107TensorBox IR represents torch tensors. Tensors are sometimes single objects owning 108storage, and sometimes views of another Tensor's storage. Mutating tensor operations 109(such as add_()) affect the underlying storage and any associated views. Other operations 110(such as .t_()) update metadata about the current view but don't modify the underlying storage. 111 112To model this in Inductor, the IR distinguishes between TensorBox, View, StorageBox and Buffer. 113 114TensorBox is the top level IR construct that any lowering should produce and maps to a torch.Tensor 115output from an operation. But just as torch.Tensors take different forms, TensorBox IR can 116reference View IR or directly reference StorageBox IRs. 
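
For example, a freshly lowered op whose result directly owns storage is
(schematically; this is an illustrative sketch rather than an exhaustive
description) wrapped as:

    TensorBox(StorageBox(Pointwise(device, dtype, inner_fn, ranges)))

and realizing that StorageBox later swaps the Pointwise out for a buffer that
owns the allocation.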
117 118Some Inductor lowerings produce new sets of 'Box'es, while others (such as .t() or other view ops) 119may take an existing TensorBox and point it to a new underlying View IR. 120 121Tensors that directly own storage are represented as a chain of: 122TensorBox -> StorageBox -> Buffer 123where Buffer is a simple (1D) allocation, and StorageBox introduces the concept of a Layout. 124 125If you mutate the data of such a tensor, we swing the StorageBox pointer to point to a new buffer 126(leaving the old buffer unmodified and functionalizing the operation). 127 128Tensors backed by views add one more indirection to the IR. 129TensorBox -> View -> StorageBox -> Buffer 130In these cases, the underlying StorageBox/Buffer will be shared with the pre-view TensorBox. 131""" 132 133 134def validate_ir(node_or_nodes): 135 def _check_tensorbox(nodes): 136 # Could expand this to check deeper properties 137 # (e.g. TensorBox points to View or StorageBox) 138 if nodes is None: 139 pass 140 elif isinstance(nodes, (list, tuple)): 141 for node in nodes: 142 _check_tensorbox(node) 143 elif isinstance(nodes, dict): 144 for node in nodes.values(): 145 _check_tensorbox(node) 146 else: 147 assert isinstance( 148 nodes, 149 ( 150 torch._inductor.ir.ExpandView, 151 DynamicScalar, 152 AssertScalar, 153 TensorBox, 154 sympy.logic.boolalg.Boolean, 155 Expr, 156 EffectfulKernel, 157 ), 158 ), f"Found {type(nodes)}, which is not a supported top level IR node. See [Note: Inductor IR]" 159 160 # Be picky about the accepted data structure (don't use pytree here) 161 _check_tensorbox(node_or_nodes) 162 163 164def ops_wrapper(name): 165 assert isinstance(name, str) 166 167 def fn(*args, **kwargs): 168 return getattr(ops, name)(*args, **kwargs) 169 170 return fn 171 172 173def inverse_reorder(order): 174 inv_order = dict(zip(order, range(len(order)))) 175 176 def reindex(index): 177 assert len(index) == len(inv_order) 178 return [index[inv_order[i]] for i in range(len(index))] 179 180 return reindex 181 182 183def same_reorder(order): 184 def reindex(index): 185 assert len(index) == len(order) 186 return [index[order[i]] for i in range(len(index))] 187 188 return reindex 189 190 191def fuse_reindexing(reindex1, reindex2): 192 def reindex(index): 193 return reindex1(reindex2(index)) 194 195 return reindex 196 197 198NHWC_STRIDE_ORDER = [3, 0, 2, 1] 199NHWDC_STRIDE_ORDER = [4, 0, 3, 2, 1] 200 201 202def stride_order2fill_order(order): 203 """ 204 Convert stride order to fill order 205 For channel last format, 206 207 stride order = [3, 0, 2, 1] and fill order = [1, 3, 2, 0] 208 """ 209 lookup = {pos: idx for idx, pos in enumerate(order)} 210 fill_order = [lookup[i] for i in range(len(order))] 211 return fill_order 212 213 214def get_stride_order(seq: Sequence[int]) -> List[int]: 215 """ 216 Convert strides to stride order 217 """ 218 sorted_idx: List[int] = argsort(seq) 219 out = [0 for _ in range(len(seq))] 220 for i, elem in enumerate(sorted_idx): 221 out[elem] = i 222 return out 223 224 225def ir_node_to_tensor(x, guard_shape=True): 226 if x is None: 227 return None 228 229 shape_fn: Callable[[Expr], Union[int, Expr]] 230 if not guard_shape: 231 shape_fn = V.graph.sizevars.size_hint 232 else: 233 shape_fn = identity 234 size = [shape_fn(s) for s in x.get_size()] 235 stride: StrideType 236 if is_storage_and_layout(x): 237 stride = [shape_fn(s) for s in x.get_layout().stride] # type: ignore[misc] 238 else: 239 stride = FlexibleLayout.contiguous_strides(size) # type: ignore[arg-type] 240 dtype = x.get_dtype() 241 device = 
x.get_device() 242 size = convert_shape_to_symint(size) 243 stride = convert_shape_to_symint(stride) 244 t = torch.empty_strided( 245 size=size, stride=stride, dtype=dtype, device=device 246 ).zero_() 247 return t 248 249 250def may_convert_to_optional(value): 251 if isinstance(value, list) and not value: 252 # [None] makes sure the cpp wrapper codegen will generate something like 253 # {c10::nullopt} instead of {} 254 return [None] 255 return value 256 257 258def get_device_type(x): 259 if getattr(x, "get_device", None): 260 return get_device_type(x.get_device()) 261 if isinstance(x, torch.device): 262 return x.type 263 return None 264 265 266def is_triton(x): 267 return is_gpu(get_device_type(x)) 268 269 270def is_cpu(x): 271 return get_device_type(x) == "cpu" 272 273 274class IRNode: 275 _current_origins: ClassVar[Set[Any]] = set() 276 277 @staticmethod 278 @contextlib.contextmanager 279 def current_origins(origins: Set[torch.fx.Node]): 280 old = IRNode._current_origins 281 IRNode._current_origins = old | origins 282 try: 283 yield 284 finally: 285 IRNode._current_origins = old 286 287 def __post_init__(self): 288 self.origins = set(self._current_origins) 289 self.traceback = traceback.format_stack() if config.debug_ir_traceback else None 290 291 def get_traceback(self): 292 return self.traceback 293 294 def common_repr(self): 295 origins = f"origins={getattr(self, 'origins', '')}" 296 if len(origins) > 64: 297 # this can get *very* long 298 origins = f"{origins[:61]}..." 299 return [origins] 300 301 def str_helper(self, lines): 302 lines = lines + self.common_repr() 303 lines = indent(",\n".join(map(str, lines))) 304 return f"{type(self).__name__}(\n{lines}\n)" 305 306 def is_user_of(self, name): 307 return name in self.get_read_names() 308 309 @cache_on_self 310 def get_read_names(self): 311 return {dep.name for dep in self.get_reads()} 312 313 def get_dtype(self): 314 return self.dtype 315 316 def get_layout(self): 317 raise NotImplementedError(f"get_layout() is not implemented by {type(self)}!") 318 319 def get_size(self): 320 raise NotImplementedError(f"get_size() is not implemented by {type(self)}!") 321 322 def get_numel(self): 323 return sympy_product(self.get_size()) 324 325 def is_zero_elements(self): 326 return V.graph.sizevars.is_expr_static_and_true(sympy.Eq(self.get_numel(), 0)) # type: ignore[arg-type] 327 328 def realize(self): 329 """ 330 If the IRNode refers to data which has not been materialized (e.g., 331 it is a Pointwise/Reduction that could potentially have more 332 compute fused into it), realize the IRNode into physical memory, 333 ending the possibility of fusing into it, but allowing, e.g., multiple 334 users to access the data without having to recompute. 335 336 Check StorageBox.realize for a particularly notable implementation. 337 338 TODO(ezyang): I think, in principle, every IRNode should have an 339 implementation of this, and most of the time no-op is OK, but you 340 really do have to audit each IRNode for this, so for now, raise 341 an error if it's not implemented. Note that some code in graph.py 342 will catch this thrown error and suppress it with a warning. 343 """ 344 raise NotImplementedError(f"realize NYI on {type(self)}") 345 346 def codegen_reference(self, writer=None): 347 raise NotImplementedError(f"codegen_reference NYI on {type(self)}") 348 349 # The abstract method declarations below serve to convince mypy that all IRNode instances have these functions 350 # defined, while having no effect at runtime. 
We cannot create stub implementations here because other parts of 351 # the code dynamically check for defined attributes. 352 get_device: Callable[[], torch.device] 353 dtype: torch.dtype 354 get_name: Callable[[], str] 355 get_reads: Callable[[], Any] 356 get_stride: Callable[[], Any] 357 get_storage_numel: Callable[[], Any] 358 has_exceeded_max_reads: Callable[[], bool] 359 make_loader: Callable[[], Callable[[Any], Any]] 360 make_indexer: Callable[[], Callable[[Any], Any]] 361 mark_reuse: Callable[[int], None] 362 realize_hint: Callable[[], None] 363 get_unbacked_symbol_uses: Callable[[], Set[sympy.Symbol]] 364 365 366@dataclasses.dataclass 367class Loops(IRNode): 368 device: torch.device 369 dtype: torch.dtype 370 inner_fn: Callable[..., Any] 371 ranges: List[Expr] 372 373 def get_unbacked_symbol_uses(self) -> Set[sympy.Symbol]: 374 return set().union( 375 *(free_unbacked_symbols(e) for e in self.ranges), 376 self.inner_fn_free_unbacked_symbols(), 377 ) 378 379 def __str__(self, names=("ranges",)): 380 return self.str_helper( 381 [ 382 f"'{self.device.type}'", 383 str(self.dtype), 384 self.inner_fn_str(), 385 ] 386 + [f"{name}={getattr(self, name)}" for name in names] 387 + [f"origin_node={self.origin_node!r}"] 388 ) 389 390 def __post_init__(self): 391 super().__post_init__() 392 self.origin_node = None 393 394 __repr__ = __str__ 395 396 def get_device(self): 397 return self.device 398 399 def get_origin_node(self): 400 return self.origin_node 401 402 def get_size(self): 403 return self.ranges 404 405 def get_pointwise_size(self): 406 return self.ranges 407 408 def is_extern(self): 409 return False 410 411 @classmethod 412 def create(cls, *args, **kwargs): 413 origin_node = kwargs.pop("origin_node", None) 414 tb = kwargs.pop("traceback", None) 415 r = cls(*args, **kwargs) 416 r.origin_node = origin_node 417 r.traceback = ( 418 tb or traceback.format_stack() if config.debug_ir_traceback else None 419 ) 420 return TensorBox.create(r) 421 422 @staticmethod 423 def _index(ranges, prefix=SymT.INDEX): 424 return [ 425 sympy.Integer(0) if s == 1 else sympy_index_symbol_with_prefix(prefix, n) 426 for n, s in enumerate(ranges) 427 ] 428 429 @cache_on_self 430 def inner_fn_opcount(self): 431 opcounter = OpCounterCSE(V.MockHandler()) 432 433 with V.set_ops_handler(opcounter), patch.object( 434 FlexibleLayout, "allow_indexing", True 435 ): 436 self.inner_fn(*self.inner_fn_args()) 437 return opcounter.op_count 438 439 def inner_fn_args(self): 440 return (self._index(self.ranges),) 441 442 def inner_fn_str(self): 443 return V.KernelFormatterHandler.ir_to_string( 444 self.inner_fn, *self.inner_fn_args() 445 ) 446 447 def has_large_inner_fn(self): 448 return self.inner_fn_opcount() > config.realize_opcount_threshold 449 450 def inner_fn_free_unbacked_symbols(self): 451 index = self._index(self.ranges) 452 return extract_free_unbacked_symbols(self.inner_fn, index) 453 454 def get_reads(self): 455 with patch.object(FlexibleLayout, "allow_indexing", True): 456 if self.get_reduction_type(): 457 return extract_read_writes( 458 self.make_loader(), 459 self.get_size(), 460 self.get_reduction_size(), 461 ).reads 462 else: 463 return extract_read_writes( 464 self.make_loader(), 465 self.get_size(), 466 ).reads 467 468 def get_reduction_size(self): 469 raise NotImplementedError( 470 f"get_reduction_size() is not implemented by {type(self)}!" 471 ) 472 473 def get_reduction_type(self): 474 raise NotImplementedError( 475 f"get_reduction_type() is not implemented by {type(self)}!" 
476 ) 477 478 def constant_to_device(self, device): 479 raise NotImplementedError( 480 f"constant_to_device() is not implemented by {type(self)}!" 481 ) 482 483 484def nop_loader_fn(idx, *, dtype): 485 if dtype.is_floating_point: 486 return ops.constant(float("nan"), dtype) 487 else: 488 return ops.constant(0, dtype) 489 490 491class Pointwise(Loops): 492 def make_loader(self): 493 # Make zero-element loops into a no-op 494 if self.is_zero_elements(): 495 return partial(nop_loader_fn, dtype=self.dtype) 496 497 return self.inner_fn 498 499 def get_reduction_size(self): 500 return [] 501 502 def get_reduction_type(self): 503 return None 504 505 def store_output(self, output_name, indexer, vars): 506 loader = self.make_loader() 507 return ops.store(output_name, indexer(vars), loader(vars)) 508 509 def constant_to_device(self, device): 510 """Move this to a given device. Requires that all reads are to constants.""" 511 loader = self.make_loader() 512 loader = patch.object(ConstantBuffer, "override_device", device)(loader) 513 return Pointwise(device, self.dtype, loader, self.ranges) 514 515 516@dataclasses.dataclass 517class Scatter(Pointwise): 518 output_indexer: Callable[[List[Expr]], Expr] 519 scatter_mode: Optional[str] = None 520 521 def constant_to_device(self, device): 522 """Move this to a given device. Requires that all reads are to constants.""" 523 loader = self.make_loader() 524 loader = patch.object(ConstantBuffer, "override_device", device)(loader) 525 return Scatter( 526 device, 527 self.dtype, 528 loader, 529 self.ranges, 530 self.output_indexer, 531 self.scatter_mode, 532 ) 533 534 def store_output(self, output_name, indexer, vars): 535 loader = self.make_loader() 536 return ops.store( 537 output_name, 538 indexer(self.output_indexer(vars)), 539 loader(vars), 540 mode=self.scatter_mode, 541 ) 542 543 544REDUCTION_COMBINE_FN = { 545 "any": ops_wrapper("logical_or"), 546 "max": ops_wrapper("maximum"), 547 "min": ops_wrapper("minimum"), 548 "prod": ops_wrapper("mul"), 549 "sum": ops_wrapper("add"), 550 "xor_sum": ops_wrapper("bitwise_xor"), 551} 552 553 554def get_reduction_combine_fn(reduction_type, dtype, arg_break_ties_left=True): 555 if reduction_type in REDUCTION_COMBINE_FN: 556 combine_fn = REDUCTION_COMBINE_FN[reduction_type] 557 elif reduction_type in {"argmax", "argmin"}: 558 559 def combine_fn(a, b): 560 a_value, a_index = a 561 b_value, b_index = b 562 563 if reduction_type == "argmin": 564 mask = ops.lt(a_value, b_value) 565 else: 566 mask = ops.gt(a_value, b_value) 567 568 equal = ops.eq(a_value, b_value) 569 if is_float_dtype(dtype): 570 a_isnan = ops.ne(a_value, a_value) 571 b_isnan = ops.ne(b_value, b_value) 572 mask = ops.logical_or(mask, ops.gt(a_isnan, b_isnan)) 573 equal = ops.logical_or(equal, ops.logical_and(a_isnan, b_isnan)) 574 575 tie = ( 576 ops.lt(a_index, b_index) 577 if arg_break_ties_left 578 else ops.gt(a_index, b_index) 579 ) 580 mask = ops.logical_or(mask, ops.logical_and(equal, tie)) 581 return ( 582 ops.where(mask, a_value, b_value), 583 ops.where(mask, a_index, b_index), 584 ) 585 586 elif reduction_type == "welford_combine": 587 588 def combine_fn(a, b): 589 a_mean, a_m2, a_weight = a 590 b_mean, b_m2, b_weight = b 591 592 delta = b_mean - a_mean 593 new_weight = a_weight + b_weight 594 w2_over_w = b_weight / new_weight 595 return ( 596 a_mean + delta * w2_over_w, 597 a_m2 + b_m2 + delta * delta * a_weight * w2_over_w, 598 new_weight, 599 ) 600 601 else: 602 raise NotImplementedError(f"unknown reduction_type={reduction_type}") 603 604 return 
combine_fn 605 606 607@dataclasses.dataclass 608class Reduction(Loops): 609 reduction_ranges: List[Expr] 610 reduction_type: str 611 # self.dtype represents the dst dtype 612 src_dtype: torch.dtype 613 reduction_hint: ReductionHint 614 615 def __str__(self): 616 return Loops.__str__( # type: ignore[call-arg] 617 self, names=("ranges", "reduction_ranges", "reduction_type") 618 ) 619 620 def __repr__(self): 621 return self.__str__() 622 623 def get_unbacked_symbol_uses(self) -> Set[sympy.Symbol]: 624 return super().get_unbacked_symbol_uses() | set().union( 625 *(free_unbacked_symbols(e) for e in self.reduction_ranges) 626 ) 627 628 def get_reduction_size(self): 629 return self.reduction_ranges 630 631 def get_reduction_type(self): 632 return self.reduction_type 633 634 def store_reduction(self, output_name, indexer, vars, reduction_vars): 635 value = ops.reduction( 636 self.dtype, 637 self.src_dtype, 638 self.reduction_type, 639 self.inner_fn(vars, reduction_vars), 640 ) 641 return ops.store_reduction(output_name, indexer(vars), value) 642 643 def index_length(self): 644 return len(self.ranges) + len(self.reduction_ranges) 645 646 def inner_fn_args(self): 647 index = self._index(self.ranges) 648 rindex = self._index(self.reduction_ranges, SymT.RINDEX) 649 return (index, rindex) 650 651 def inner_fn_free_unbacked_symbols(self): 652 index = self._index(self.ranges) 653 rindex = self._index(self.reduction_ranges, SymT.RINDEX) 654 return extract_free_unbacked_symbols(self.inner_fn, index, rindex) 655 656 def constant_to_device(self, device): 657 """Move this to a given device. Requires that all reads are to constants.""" 658 loader = self.make_loader() 659 loader = patch.object(ConstantBuffer, "override_device", device)(loader) 660 return Reduction( 661 device, 662 self.dtype, 663 loader, 664 self.ranges, 665 self.reduction_ranges, 666 self.reduction_type, 667 self.src_dtype, 668 ReductionHint.DEFAULT, 669 ) 670 671 @staticmethod 672 def num_splits( 673 device, 674 dst_dtype, 675 src_dtype, 676 inner_fn, 677 ranges, 678 reduction_ranges, 679 reduction_type, 680 reduction_numel, 681 input_node: Optional[IRNode] = None, 682 ): 683 def _is_static(x): 684 return isinstance(x, (int, sympy.Integer)) 685 686 reduction_numel_hint = V.graph.sizevars.symbolic_hint(reduction_numel) 687 numel_hint = V.graph.sizevars.symbolic_hint(sympy_product(ranges)) 688 689 should_split = ( 690 is_gpu(get_device_type(device)) 691 and reduction_type 692 not in { 693 "argmax", 694 "argmin", 695 } 696 and config.split_reductions 697 # We don't support unbacked symints 698 and _is_static(reduction_numel_hint) 699 and _is_static(numel_hint) 700 ) 701 if not should_split: 702 return ReductionHint.DEFAULT, 1 703 704 device_interface = get_interface_for_device(get_device_type(device)) 705 device_properties = device_interface.Worker.get_device_properties(device) 706 if get_device_type(device) == "xpu": 707 num_sm = device_properties.gpu_subslice_count 708 else: 709 # default is cuda behavior 710 num_sm = device_properties.multi_processor_count 711 712 min_elements_per_thread = 32 713 max_elements_per_thread = 512 714 threads_per_sm = 2048 715 min_elements_per_device = min_elements_per_thread * num_sm * threads_per_sm 716 max_elements_per_device = max_elements_per_thread * num_sm * threads_per_sm 717 718 def inner_reduction_splits(reduction_numel_hint, numel_hint): 719 # do heuristics that's close to eager mode for split inner reduction 720 # we leak reduction autotune configs here, and will need to refactor to avoid this later 
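            # Illustrative walkthrough (assumed, A100-like numbers, not taken from
            # the source): with num_sm = 108, threads_per_sm = 2048 and
            # num_threads = 256 below, min_elements_per_device = 32 * 108 * 2048 =
            # 7,077,888. A single-output reduction (numel_hint == 1) over 2**20
            # elements falls below that bound, so split_size becomes
            # min_elements_per_thread = 32 and this helper returns
            # ceil(2**20 / (32 * 256)) = 128 splits.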
            num_warps = 8
            num_threads = 32 * num_warps
            if numel_hint >= 2 * num_sm:  # don't split if there are enough outputs
                return 1
            if reduction_numel_hint <= 8192:
                return 1
            if reduction_numel_hint * numel_hint <= min_elements_per_device:
                split_size = min_elements_per_thread
            elif reduction_numel_hint * numel_hint < max_elements_per_device:
                target_blocks = num_sm * threads_per_sm // (2 * num_threads)
                blocks_per_output = (target_blocks + numel_hint - 1) // numel_hint
                tmp_split_size = (
                    reduction_numel_hint + num_threads * blocks_per_output - 1
                ) // (num_threads * blocks_per_output)
                divisors = sympy.divisors(reduction_numel_hint)
                closest = min(divisors, key=lambda x: abs(x - tmp_split_size))
                if abs(closest - tmp_split_size) < 30:
                    # prefer even splits, but never smaller than min_elements_per_thread
                    split_size = max(closest, min_elements_per_thread)
                else:
                    split_size = tmp_split_size
            else:
                divisors = sympy.divisors(reduction_numel_hint)
                closest = min(divisors, key=lambda x: abs(x - max_elements_per_thread))
                if abs(closest - max_elements_per_thread) < 50:
                    # prefer even splits
                    split_size = closest
                else:
                    split_size = max_elements_per_thread
            return (reduction_numel_hint + split_size * num_threads - 1) // (
                split_size * num_threads
            )

        def outer_reduction_splits(reduction_numel_hint, numel_hint):
            # TODO the best heuristic currently has XBLOCK (corresponding to numel_hint) 128
            # extend to even smaller number of outputs
            num_warps = 8
            num_threads = num_warps * 32
            rvals_per_thread = 4  # comes from heuristics, refactor to not leak here
            xvals_per_block = 128
            xblocks = (numel_hint + xvals_per_block - 1) // xvals_per_block
            if reduction_numel_hint * numel_hint < min_elements_per_device:
                split_size = min_elements_per_thread
            elif reduction_numel_hint * numel_hint < max_elements_per_device:
                target_blocks = num_sm * threads_per_sm // (num_threads)
                target_blocks = (target_blocks + xblocks - 1) // xblocks
                tmp_split_size = (
                    reduction_numel_hint + rvals_per_thread * target_blocks - 1
                ) // (rvals_per_thread * target_blocks)
                divisors = sympy.divisors(reduction_numel_hint)
                closest = min(divisors, key=lambda x: abs(x - tmp_split_size))
                if abs(tmp_split_size - closest) < 20:
                    split_size = max(closest, min_elements_per_thread)
                else:
                    split_size = tmp_split_size
            else:
                divisors = sympy.divisors(reduction_numel_hint)
                closest = min(divisors, key=lambda x: abs(x - max_elements_per_thread))
                if abs(closest - max_elements_per_thread) < 50:
                    # prefer even splits
                    split_size = closest
                else:
                    split_size = max_elements_per_thread

            return (reduction_numel_hint + rvals_per_thread * split_size - 1) // (
                rvals_per_thread * split_size
            )

        # easy cases
        if numel_hint == 1:
            split = inner_reduction_splits(reduction_numel_hint, numel_hint)
            if split == 1:
                # No need to split.
794 return ReductionHint.INNER, split 795 if input_node is not None and isinstance(input_node, TensorBox): 796 new_ranges, new_reduction_ranges = extract_input_node_reduction_ranges( 797 input_node 798 ) 799 if new_ranges is not None and new_reduction_ranges is not None: 800 extracted_numel_hint = V.graph.sizevars.symbolic_hint( 801 sympy_product(new_ranges + new_reduction_ranges) 802 ) 803 if reduction_numel_hint == extracted_numel_hint: 804 log.debug( 805 "Use previous IRNode's range and reduction_ranges instead of split. " 806 "current ranges: %s, current reduction ranges: %s, current split: %d, " 807 "new ranges: %s, new reduction ranges: %s", 808 ranges, 809 reduction_ranges, 810 split, 811 new_ranges, 812 new_reduction_ranges, 813 ) 814 # If the input_node or its dependent nodes are also Reduction nodes, 815 # use reduction_sizes of this node or its dependent nodes directly. 816 return ReductionHint.INNER, -1 817 return ReductionHint.INNER, split 818 if ( 819 reduction_numel_hint <= min_elements_per_thread 820 or numel_hint >= num_sm * 2 * 32 821 ): 822 return ReductionHint.DEFAULT, 1 823 824 r = Reduction( 825 device, 826 dst_dtype, 827 inner_fn, 828 ranges, 829 reduction_ranges, 830 reduction_type, 831 src_dtype, 832 ReductionHint.DEFAULT, 833 ) 834 835 def get_read_indices(r): 836 cb = ComputedBuffer( 837 name=None, 838 layout=FlexibleLayout( 839 device=r.get_device(), 840 dtype=r.get_dtype(), 841 size=r.get_size(), 842 ), 843 data=r, 844 ) 845 read_writes = cb.get_read_writes() 846 # try finding the full size producer 847 # TODO this will fail for something like ((1, N) * (N, 1)).sum() 848 # this would also possibly be wrong for producers with the different contiguity but we hope those cases are rare 849 range_vars = [ 850 r 851 for r in read_writes.range_vars 852 if isinstance(r, sympy.Expr) and not isinstance(r, sympy.Number) 853 ] 854 indices = [] 855 changed = False 856 for md in sorted(read_writes.reads, key=lambda x: x.name): 857 if all(r in md.index.free_symbols for r in range_vars): 858 indices.append(md.index) 859 if md.name in V.graph.name_to_buffer: 860 buf = V.graph.name_to_buffer[md.name] 861 original_stride = buf.layout.stride 862 buf.decide_layout() 863 if buf.layout.stride != original_stride: 864 changed = True 865 return indices, changed 866 867 indices, changed = get_read_indices(r) 868 if changed: 869 indices, _ = get_read_indices(r) 870 871 if len(indices) == 0: 872 # TODO determine splits when all inputs are broadcast 873 return ReductionHint.DEFAULT, 1 874 875 (_, reduction_vars), ranges = dependencies.index_vars_squeeze( 876 r.get_size(), r.get_reduction_size() 877 ) 878 num_outer = 0 879 num_inner = 0 880 for i in indices: 881 i = V.graph.sizevars.simplify_with_ranges(i, ranges) 882 strides = V.graph.sizevars.stride_hints(i, reduction_vars, ranges.keys()) 883 outer = all(s > 1 for s in strides) 884 if outer: 885 num_outer += 1 886 else: 887 num_inner += 1 888 if num_inner > num_outer: 889 return ReductionHint.INNER, inner_reduction_splits( 890 reduction_numel_hint, numel_hint 891 ) 892 else: 893 return ReductionHint.OUTER, outer_reduction_splits( 894 reduction_numel_hint, numel_hint 895 ) 896 897 @staticmethod 898 def _unroll_reduction_fn(inner_fn, reduction_ranges, reduction_type, src_dtype): 899 """Convert inner_fn from a reduction to an pointwise""" 900 reduction_ranges = [ 901 V.graph.sizevars.evaluate_static_shape(x) for x in reduction_ranges 902 ] 903 904 combine_fn = get_reduction_combine_fn(reduction_type, src_dtype) 905 906 def fn(index): 907 
return functools.reduce( 908 combine_fn, 909 ( 910 value_fn(index, rindex) 911 for rindex in itertools.product( 912 *[range(x) for x in reduction_ranges] 913 ) 914 ), 915 ) 916 917 if reduction_type in ("argmin", "argmax"): 918 flatten_index = FixedLayout( 919 None, # type: ignore[arg-type] 920 None, # type: ignore[arg-type] 921 reduction_ranges, 922 FlexibleLayout.contiguous_strides(reduction_ranges), 923 ).make_indexer() 924 925 def value_fn(index, rindex): 926 rindex = [sympy.expand(i) for i in rindex] 927 return ( 928 inner_fn(index, rindex), 929 ops.index_expr(flatten_index(rindex), torch.int64), 930 ) 931 932 return lambda index: fn(index)[1] 933 else: 934 value_fn = inner_fn 935 return fn 936 937 @classmethod 938 def create( # type: ignore[override] 939 cls, 940 device: torch.device, 941 dst_dtype: torch.dtype, 942 src_dtype: torch.dtype, 943 inner_fn: Callable[..., Any], 944 ranges: List[Expr], 945 reduction_ranges: List[Expr], 946 reduction_type: str, 947 reduction_hint: ReductionHint = ReductionHint.DEFAULT, 948 input_node: Optional[IRNode] = None, 949 ): 950 reduction_numel = V.graph.sizevars.simplify(sympy_product(reduction_ranges)) 951 952 if reduction_numel == 0: 953 # N.B. This is a hack to generate the literal of the given type 954 # Ideally, we should be fixing `def constant` in triton.py 955 # but it breaks due to hardcoded dtypes in other places 956 def py_cnst(val): 957 return ( 958 bool(val) 959 if dst_dtype == torch.bool 960 else float(val) 961 if dst_dtype.is_floating_point 962 else int(val) 963 ) 964 965 rtypes_to_inits = { 966 "sum": py_cnst(0), 967 "xor_sum": py_cnst(0), 968 "prod": py_cnst(1), 969 "any": py_cnst(0), 970 # "all" is desugared to `!any(!val)` 971 } 972 973 assert ( 974 reduction_type in rtypes_to_inits.keys() 975 ), f"{reduction_type} not supported for zero-dimension tensors!" 
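            # Illustrative example (assumed input, not from the original comments):
            # lowering torch.sum(torch.empty(4, 0), dim=1) lands here and produces a
            # Pointwise of zeros over ranges == [4]; a "prod" reduction would instead
            # be filled with 1.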
976 977 def const_fn(index): 978 return ops.constant(rtypes_to_inits[reduction_type], dst_dtype) 979 980 return Pointwise.create( 981 device=device, 982 dtype=src_dtype, 983 inner_fn=const_fn, 984 ranges=list(ranges), 985 ) 986 987 if reduction_numel == 1: 988 # this reduction is actually a pointwise op 989 if reduction_type in ("argmin", "argmax"): 990 991 def fn(index): 992 return ops.constant(0, dst_dtype) 993 994 else: 995 996 def fn(index): 997 reduction_index = [sympy.Integer(0) for _ in reduction_ranges] 998 return inner_fn(index, reduction_index) 999 1000 return Pointwise.create(device, dst_dtype, fn, ranges) 1001 1002 if ( 1003 isinstance(reduction_numel, sympy.Integer) 1004 and V.graph.sizevars.size_hint(reduction_numel) 1005 < config.unroll_reductions_threshold 1006 and sympy_product(ranges) != 1 1007 ): 1008 return Pointwise.create( 1009 device, 1010 dst_dtype, 1011 cls._unroll_reduction_fn( 1012 inner_fn, reduction_ranges, reduction_type, src_dtype 1013 ), 1014 ranges, 1015 ) 1016 1017 # triton doesn't support reduce to single element well, so break it up 1018 hint, split = cls.num_splits( 1019 device, 1020 dst_dtype, 1021 src_dtype, 1022 inner_fn, 1023 ranges, 1024 reduction_ranges, 1025 reduction_type, 1026 reduction_numel, 1027 input_node, 1028 ) 1029 # intermediate reduction in split can contain complex indexing, 1030 # and num_splits will fail to correctly set the hint 1031 # reuse the passed hint if available 1032 if reduction_hint == ReductionHint.DEFAULT: 1033 reduction_hint = hint 1034 if split == -1: 1035 assert input_node is not None 1036 new_ranges, new_reduction_ranges = extract_input_node_reduction_ranges( 1037 input_node # type: ignore[arg-type] 1038 ) 1039 assert new_ranges is not None 1040 assert new_reduction_ranges is not None 1041 return cls.create_multilayer_existing_ranges( 1042 device, 1043 dst_dtype, 1044 src_dtype, 1045 inner_fn, 1046 ranges, 1047 reduction_ranges, 1048 new_ranges, 1049 new_reduction_ranges, 1050 reduction_type, 1051 reduction_hint, 1052 ) 1053 elif split > 1: 1054 # triton doesn't support reduce to single element well, so break it up 1055 return cls.create_multilayer( 1056 device, 1057 dst_dtype, 1058 src_dtype, 1059 inner_fn, 1060 ranges, 1061 reduction_ranges, 1062 reduction_type, 1063 split, 1064 reduction_hint, 1065 ) 1066 1067 return TensorBox.create( 1068 Reduction( 1069 device, 1070 dst_dtype, 1071 inner_fn, 1072 ranges, 1073 reduction_ranges, 1074 reduction_type, 1075 src_dtype, 1076 reduction_hint, 1077 ) 1078 ) 1079 1080 @staticmethod 1081 def default_accumulator(reduction_type, dtype): 1082 if reduction_type in {"max", "argmax"}: 1083 if is_float_dtype(dtype): 1084 return float("-inf") 1085 elif is_boolean_dtype(dtype): 1086 return 0 1087 else: 1088 return torch.iinfo(dtype).min 1089 if reduction_type in {"min", "argmin"}: 1090 if is_float_dtype(dtype): 1091 return float("inf") 1092 elif is_boolean_dtype(dtype): 1093 return 1 1094 else: 1095 return torch.iinfo(dtype).max 1096 1097 return { 1098 "sum": 0, 1099 "prod": 1, 1100 "xor_sum": 0, 1101 "any": 0, 1102 "welford_reduce": (0, 0, 0), 1103 "welford_combine": (0, 0, 0), 1104 }[reduction_type] 1105 1106 @staticmethod 1107 def default_value(reduction_type, dtype): 1108 if reduction_type == "welford_reduce": 1109 return 0 1110 return Reduction.default_accumulator(reduction_type, dtype) 1111 1112 @staticmethod 1113 def _multilayer_second_step_hint( 1114 split: int, numel_hint: int, reduction_hint: ReductionHint 1115 ) -> ReductionHint: 1116 if split == -1: 1117 return 
reduction_hint 1118 if split <= 512 and numel_hint <= 512 and reduction_hint == ReductionHint.OUTER: 1119 return ReductionHint.OUTER_TINY 1120 if ( 1121 split <= 1024 1122 and numel_hint <= 256 1123 and reduction_hint == ReductionHint.OUTER 1124 ): 1125 return ReductionHint.OUTER_TINY 1126 1127 return reduction_hint 1128 1129 @classmethod 1130 def _multilayer_wrap_loader( 1131 cls, 1132 loader, 1133 reduction_ranges, 1134 reduction_numel, 1135 split, 1136 block_size, 1137 default, 1138 ): 1139 reindex = View.dynamic_reshape_indexer(reduction_ranges, [reduction_numel]) 1140 need_mask = not V.graph.sizevars.is_expr_static_and_true( 1141 sympy.Eq(reduction_numel % split, 0) # type: ignore[arg-type] 1142 ) 1143 1144 def wrapper_fn(index, reduction_index): 1145 (reduction_index,) = reduction_index 1146 *new_index, reduction_block = index 1147 indices = block_size * reduction_block + reduction_index 1148 1149 def body(): 1150 return loader(new_index, reindex([indices])) 1151 1152 if need_mask: 1153 mask = ops.lt( 1154 ops.index_expr(indices, torch.int32), 1155 ops.index_expr(reduction_numel, torch.int32), 1156 ) 1157 return ops.masked(mask, body, default) 1158 else: 1159 return body() 1160 1161 return wrapper_fn 1162 1163 @classmethod 1164 def _multilayer_wrap_loader_existing_ranges( 1165 cls, 1166 loader, 1167 original_ranges, 1168 original_reduction_ranges, 1169 new_ranges, 1170 new_reduction_ranges, 1171 default, 1172 ): 1173 assert all( 1174 r == 1 for r in original_ranges 1175 ), f"Only enabled for numel_hint == 1, found {original_ranges=}" 1176 reindex = View.dynamic_reshape_indexer( 1177 original_reduction_ranges, tuple(new_ranges) + tuple(new_reduction_ranges) 1178 ) 1179 1180 def wrapper_fn(merged_index, new_reduction_index): 1181 original_idx = merged_index[: len(original_ranges)] 1182 new_index = merged_index[len(original_ranges) :] 1183 return loader( 1184 original_idx, 1185 reindex(tuple(new_index) + tuple(new_reduction_index)), 1186 ) 1187 1188 return wrapper_fn 1189 1190 @classmethod 1191 def create_multilayer_helper( 1192 cls, 1193 device: torch.device, 1194 dst_dtype: torch.dtype, 1195 src_dtype: torch.dtype, 1196 wrapper_fn: Callable[..., Any], 1197 original_ranges: List[Expr], 1198 original_reduction_ranges: List[Expr], 1199 new_ranges: List[Expr], 1200 new_reduction_ranges: List[Expr], 1201 reduction_type: str, 1202 split: int, 1203 reduction_hint: ReductionHint, 1204 ): 1205 """ 1206 Break a large reduction up into multiple smaller reductions 1207 recursively 1208 """ 1209 # triton will automatically compute reductions in fp32 if reducing over fp16/bf16 1210 # within the kernel. 
keep the intermediate in fp32 so as to keep the whole reduction 1211 # in fp32 and not reduce precision by breaking up the kernel into multiple layers 1212 intermediate_dtype = ( 1213 dst_dtype 1214 if dst_dtype not in (torch.float16, torch.bfloat16) 1215 else torch.float 1216 ) 1217 intermediate = Reduction.create( 1218 device, 1219 intermediate_dtype, 1220 src_dtype, 1221 wrapper_fn, 1222 new_ranges, 1223 new_reduction_ranges, 1224 reduction_type, 1225 reduction_hint, 1226 ) 1227 intermediate.realize() 1228 intermediate_loader = intermediate.make_loader() 1229 1230 def intermediate_fn(index, reduction_index): 1231 return intermediate_loader([*index, *reduction_index]) 1232 1233 numel_hint = V.graph.sizevars.size_hint(sympy_product(original_ranges)) 1234 reduction_hint = cls._multilayer_second_step_hint( 1235 split, numel_hint, reduction_hint 1236 ) 1237 1238 assert original_ranges == new_ranges[: len(original_ranges)] 1239 return TensorBox.create( 1240 Reduction( 1241 device, 1242 dst_dtype, 1243 intermediate_fn, 1244 original_ranges, 1245 new_ranges[len(original_ranges) :], 1246 reduction_type, 1247 src_dtype, 1248 reduction_hint, 1249 ) 1250 ) 1251 1252 @classmethod 1253 def create_multilayer( 1254 cls, 1255 device: torch.device, 1256 dst_dtype: torch.dtype, 1257 src_dtype: torch.dtype, 1258 inner_fn: Callable[..., Any], 1259 ranges: List[Expr], 1260 reduction_ranges: List[Expr], 1261 reduction_type: str, 1262 split: int, 1263 reduction_hint: ReductionHint, 1264 ): 1265 """ 1266 Break a large reduction up into multiple smaller reductions 1267 recursively 1268 """ 1269 # TODO(jansel): realize the reduction so we can do dynamic indexing 1270 reduction_numel = sympy_product(reduction_ranges) 1271 block_size = FloorDiv(reduction_numel + (split - 1), split) 1272 default = cls.default_value(reduction_type, dst_dtype) 1273 wrapper_fn = cls._multilayer_wrap_loader( 1274 inner_fn, reduction_ranges, reduction_numel, split, block_size, default 1275 ) 1276 1277 return cls.create_multilayer_helper( 1278 device, 1279 dst_dtype, 1280 src_dtype, 1281 wrapper_fn, 1282 ranges, 1283 reduction_ranges, 1284 [*ranges, split], # type: ignore[list-item] 1285 [block_size], 1286 reduction_type, 1287 split, 1288 reduction_hint, 1289 ) 1290 1291 @classmethod 1292 def create_multilayer_existing_ranges( 1293 cls, 1294 device: torch.device, 1295 dst_dtype: torch.dtype, 1296 src_dtype: torch.dtype, 1297 inner_fn: Callable[..., Any], 1298 original_ranges: List[Expr], 1299 original_reduction_ranges: List[Expr], 1300 new_ranges: List[Expr], 1301 new_reduction_ranges: List[Expr], 1302 reduction_type: str, 1303 reduction_hint: ReductionHint, 1304 ): 1305 """ 1306 Break a large reduction up into multiple smaller reductions 1307 recursively 1308 """ 1309 default = cls.default_value(reduction_type, dst_dtype) 1310 wrapper_fn = cls._multilayer_wrap_loader_existing_ranges( 1311 inner_fn, 1312 original_ranges, 1313 original_reduction_ranges, 1314 new_ranges, 1315 new_reduction_ranges, 1316 default, 1317 ) 1318 return cls.create_multilayer_helper( 1319 device, 1320 dst_dtype, 1321 src_dtype, 1322 wrapper_fn, 1323 original_ranges, 1324 original_reduction_ranges, 1325 [*original_ranges, *new_ranges], 1326 new_reduction_ranges, 1327 reduction_type, 1328 -1, 1329 reduction_hint, 1330 ) 1331 1332 1333def num_reduction_outputs(reduction_type): 1334 return 3 if "welford" in reduction_type else 1 1335 1336 1337class WelfordReduction(Reduction): 1338 output_index: int 1339 1340 def __init__( 1341 self, 1342 device, 1343 dtype, 1344 
inner_fns, 1345 ranges, 1346 reduction_ranges, 1347 reduction_type, 1348 reduction_hint, 1349 output_index, 1350 ): 1351 if len(inner_fns) == 1: 1352 loader = inner_fns[0] 1353 else: 1354 1355 def loader(idx, reduction_idx): 1356 return tuple(fn(idx, reduction_idx) for fn in inner_fns) 1357 1358 super().__init__( 1359 device, 1360 dtype, 1361 loader, 1362 ranges, 1363 reduction_ranges, 1364 reduction_type, 1365 dtype, 1366 reduction_hint, 1367 ) 1368 self.output_index = output_index 1369 1370 def store_reduction(self, output_name, indexer, vars, reduction_vars): 1371 values = ops.reduction( 1372 self.dtype, 1373 self.src_dtype, 1374 self.reduction_type, 1375 self.inner_fn(vars, reduction_vars), 1376 ) 1377 value = values[self.output_index] 1378 return ops.store_reduction(output_name, indexer(vars), value) 1379 1380 @classmethod 1381 def create( # type: ignore[override] 1382 cls, 1383 device: torch.device, 1384 dtype: torch.dtype, 1385 inner_fns: Sequence[Callable[..., Any]], 1386 ranges: List[Expr], 1387 reduction_ranges: List[Expr], 1388 reduction_type: str, 1389 reduction_hint: ReductionHint = ReductionHint.DEFAULT, 1390 ): 1391 assert reduction_type in {"welford_reduce", "welford_combine"} 1392 1393 reduction_numel = V.graph.sizevars.simplify(sympy_product(reduction_ranges)) 1394 1395 def const(val): 1396 def inner_fn(idx): 1397 return ops.constant( 1398 val, 1399 dtype, 1400 ) 1401 1402 return Pointwise.create( 1403 device=device, 1404 dtype=dtype, 1405 inner_fn=inner_fn, 1406 ranges=list(ranges), 1407 ) 1408 1409 if reduction_numel == 0: 1410 mean = const(0) 1411 m2 = const(0) 1412 weight = const(0) 1413 return mean, m2, weight 1414 1415 if reduction_numel == 1: 1416 1417 def copy(loader): 1418 def inner_fn(idx): 1419 reduction_index = [sympy.Integer(0) for _ in reduction_ranges] 1420 return loader(idx, reduction_index) 1421 1422 return Pointwise.create( 1423 device=device, 1424 dtype=dtype, 1425 inner_fn=inner_fn, 1426 ranges=list(ranges), 1427 ) 1428 1429 if reduction_type == "welford_reduce": 1430 return copy(inner_fns[0]), const(0), const(1) 1431 else: 1432 return tuple(copy(fn) for fn in inner_fns) 1433 1434 # TODO: Unrolled reduction 1435 # if ( 1436 # isinstance(reduction_numel, sympy.Integer) 1437 # and V.graph.sizevars.size_hint(reduction_numel) 1438 # < config.unroll_reductions_threshold 1439 # and sympy_product(ranges) != 1 1440 # ): 1441 # return Pointwise.create( 1442 # device, 1443 # dst_dtype, 1444 # cls._unroll_reduction_fn( 1445 # inner_fn, reduction_ranges, reduction_type, src_dtype 1446 # ), 1447 # ranges, 1448 # ) 1449 1450 # triton doesn't support reduce to single element well, so break it up 1451 hint, split = Reduction.num_splits( 1452 device, 1453 dtype, 1454 dtype, 1455 inner_fns[0], 1456 ranges, 1457 reduction_ranges, 1458 reduction_type=reduction_type, 1459 reduction_numel=reduction_numel, 1460 ) 1461 # intermediate reduction in split can contain complex indexing, 1462 # and num_splits will fail to correctly set the hint 1463 # reuse the passed hint if available 1464 if reduction_hint == ReductionHint.DEFAULT: 1465 reduction_hint = hint 1466 if split > 1: 1467 # triton doesn't support reduce to single element well, so break it up 1468 return cls.create_multilayer( 1469 device, 1470 dtype, 1471 inner_fns, 1472 ranges, 1473 reduction_ranges, 1474 reduction_type, 1475 split, 1476 reduction_hint, 1477 ) 1478 1479 results = [ 1480 TensorBox.create( 1481 WelfordReduction( 1482 device, 1483 dtype, 1484 inner_fns, 1485 ranges, 1486 reduction_ranges, 1487 
reduction_type, 1488 reduction_hint, 1489 output_idx, 1490 ) 1491 ) 1492 for output_idx in range(3) 1493 ] 1494 for t in results: 1495 t.realize() 1496 return results 1497 1498 @staticmethod 1499 def default_value(reduction_type, dtype): 1500 return (0, 0, 0) 1501 1502 @classmethod 1503 def create_multilayer( # type: ignore[override] 1504 cls, 1505 device: torch.device, 1506 dtype: torch.dtype, 1507 inner_fns: Sequence[Callable[..., Any]], 1508 ranges: List[Expr], 1509 reduction_ranges: List[Expr], 1510 reduction_type: str, 1511 split: int, 1512 reduction_hint: ReductionHint, 1513 ): 1514 """ 1515 Break a large reduction up into multiple smaller reductions 1516 recursively 1517 """ 1518 reduction_numel = sympy_product(reduction_ranges) 1519 need_mask = not V.graph.sizevars.is_expr_static_and_true( 1520 sympy.Eq(reduction_numel % split, 0) # type: ignore[arg-type] 1521 ) 1522 1523 if need_mask and reduction_type != "welford_combine": 1524 # If we need mask, then "welford_reduce" doesn't work because 1525 # masked inputs shouldn't count towards the welford weight 1526 1527 def constant(idx, reduction_idx, value): 1528 return ops.constant(value, dtype) 1529 1530 return cls.create_multilayer( 1531 device=device, 1532 dtype=dtype, 1533 inner_fns=( 1534 inner_fns[0], 1535 partial(constant, value=0), 1536 partial(constant, value=1), 1537 ), 1538 ranges=ranges, 1539 reduction_ranges=reduction_ranges, 1540 reduction_type="welford_combine", 1541 split=split, 1542 reduction_hint=reduction_hint, 1543 ) 1544 1545 block_size = FloorDiv(reduction_numel + (split - 1), split) 1546 intermediates = WelfordReduction.create( 1547 device, 1548 dtype, 1549 tuple( 1550 cls._multilayer_wrap_loader( 1551 loader, 1552 reduction_ranges, 1553 reduction_numel, 1554 split, 1555 block_size, 1556 default=0, 1557 ) 1558 for loader in inner_fns 1559 ), 1560 [*ranges, split], # type: ignore[list-item] 1561 [block_size], 1562 reduction_type, 1563 reduction_hint, 1564 ) 1565 for i in intermediates: 1566 i.realize() 1567 1568 i_loaders = [i.make_loader() for i in intermediates] 1569 1570 def intermediate_loader_fn(index, reduction_index, loader): 1571 return loader([*index, *reduction_index]) 1572 1573 numel_hint = V.graph.sizevars.size_hint(sympy_product(ranges)) 1574 reduction_hint = cls._multilayer_second_step_hint( 1575 split, numel_hint, reduction_hint 1576 ) 1577 return WelfordReduction.create( 1578 device, 1579 dtype, 1580 tuple( 1581 partial(intermediate_loader_fn, loader=i.make_loader()) 1582 for i in intermediates 1583 ), 1584 ranges, 1585 [split], # type: ignore[list-item] 1586 # welford_reduce turns one input into three outputs, which are combined with welford_combine 1587 "welford_combine", 1588 reduction_hint, 1589 ) 1590 1591 1592@dataclasses.dataclass 1593class Scan(Loops): 1594 scan_ranges: List[Expr] 1595 size: List[Expr] 1596 combine_fn: Callable[[Tuple[Any, ...], Tuple[Any, ...]], Tuple[Any, ...]] 1597 reindex: Callable[[List[Expr], List[Expr]], List[Expr]] 1598 reduction_hint: ReductionHint 1599 output_index: int 1600 # output_index indexes the following tuples 1601 dtypes: Tuple[torch.dtype, ...] 1602 inner_fns: Tuple[Callable[..., Any], ...] 1603 1604 # HACK we mimick reduction 1605 1606 def get_unbacked_symbol_uses(self) -> Set[sympy.Symbol]: 1607 # TODO: Can combine_fn/reindex close over unbacked symbols? 
If so, we 1608 # need to explicitly represent the closure so we can pull out unbacked 1609 # symbols here 1610 return ( 1611 super().get_unbacked_symbol_uses() 1612 | set().union(*(free_unbacked_symbols(e) for e in self.scan_ranges)) 1613 | set().union(*(free_unbacked_symbols(e) for e in self.size)) 1614 ) 1615 1616 def __post_init__(self): 1617 assert len(self.ranges) + len(self.scan_ranges) == len(self.size) 1618 super().__post_init__() 1619 1620 def store_reduction(self, output_name, indexer, vars, scan_vars): 1621 idx = self.reindex(vars, scan_vars) 1622 values = [inner_fn(idx) for inner_fn in self.inner_fns] 1623 result = ops.scan(self.dtypes, self.combine_fn, values) 1624 return ops.store(output_name, indexer(idx), result[self.output_index]) 1625 1626 def get_reduction_type(self): 1627 # return self.scan_op 1628 return "custom" 1629 1630 def get_reduction_size(self): 1631 return self.scan_ranges 1632 1633 def get_size(self): 1634 return self.size 1635 1636 def get_pointwise_size(self): 1637 return self.ranges 1638 1639 def index_length(self): 1640 return len(self.ranges) + len(self.scan_ranges) 1641 1642 def inner_fn_args(self): 1643 index = self._index(self.ranges) 1644 rindex = self._index(self.scan_ranges, SymT.RINDEX) 1645 idx = self.reindex(index, rindex) 1646 return (idx,) 1647 1648 def inner_fn_free_unbacked_symbols(self): 1649 index = self._index(self.ranges) 1650 rindex = self._index(self.scan_ranges, SymT.RINDEX) 1651 idx = self.reindex(index, rindex) 1652 return extract_free_unbacked_symbols(self.inner_fn, idx) 1653 1654 @classmethod 1655 def create( 1656 cls, 1657 device: torch.device, 1658 dtypes: Tuple[torch.dtype, ...], 1659 inner_fns: Tuple[Callable[[List[Expr]], Any], ...], 1660 size: List[Expr], 1661 axis: int, 1662 combine_fn: Callable[[Tuple[Any, ...], Tuple[Any, ...]], Tuple[Any, ...]], 1663 reduction_hint: ReductionHint = ReductionHint.DEFAULT, 1664 **kwargs, 1665 ) -> List[Optional["TensorBox"]]: 1666 pointwise_ranges = [*size[:axis], *size[axis + 1 :]] 1667 scan_ranges = [size[axis]] 1668 1669 if not is_gpu(device.type): 1670 # TODO: CPU support 1671 return [None] * len(dtypes) 1672 1673 if torch.version.hip is not None and len(dtypes) > 1: 1674 # TODO: Remove this when ROCm triton adds support for multiple inputs 1675 return [None] * len(dtypes) 1676 1677 sizevars = V.graph.sizevars 1678 scan_numel = sizevars.simplify(sympy_product(scan_ranges)) 1679 1680 assert len(dtypes) == len(inner_fns) 1681 1682 # Scan with a single element is just a copy 1683 if sizevars.is_expr_static_and_true(sympy.Le(scan_numel, 1)): # type: ignore[arg-type] 1684 return [ 1685 Pointwise.create( 1686 device=device, 1687 dtype=dtypes[output_index], 1688 inner_fn=inner_fns[output_index], 1689 ranges=size, 1690 ) 1691 for output_index in range(len(dtypes)) 1692 ] 1693 1694 reduction_hint, num_splits = cls.num_splits( 1695 device=device, 1696 dtype=dtypes[0], 1697 inner_fn=inner_fns[0], 1698 axis=axis, 1699 pointwise_ranges=pointwise_ranges, 1700 scan_ranges=scan_ranges, 1701 combine_fn=combine_fn, 1702 scan_numel=scan_numel, 1703 ) 1704 scan_type = Scan if num_splits <= 1 else SplitScan 1705 1706 if num_splits > 1 and torch.version.hip is not None: 1707 # Fallback for split-scan on ROCm 1708 return [None] * len(dtypes) 1709 1710 if num_splits > 1 and len(dtypes) > 1: 1711 # Fallback for split-scans for multiple inputs 1712 return [None] * len(dtypes) 1713 1714 def reindex(index, scan_index): 1715 assert len(scan_index) == len(scan_ranges) 1716 assert len(index) == len(pointwise_ranges) 
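            # Illustrative (assumed shapes, not from the original comments): for
            # size = [B, T, C] and axis = 1, pointwise_ranges is [B, C] and
            # scan_ranges is [T], so reindex([b, c], [t]) restores the original
            # index order [b, t, c].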
            return [*index[:axis], *scan_index, *index[axis:]]

        results = [
            TensorBox.create(
                scan_type(
                    device=device,
                    dtype=dtypes[output_index],
                    dtypes=dtypes,
                    inner_fn=inner_fns[output_index],
                    inner_fns=inner_fns,
                    size=size,
                    ranges=pointwise_ranges,
                    scan_ranges=scan_ranges,
                    combine_fn=combine_fn,
                    reindex=reindex,
                    reduction_hint=reduction_hint,
                    output_index=output_index,
                    **kwargs,
                )
            )
            for output_index in range(len(dtypes))
        ]

        for result in results:
            result.realize()

        return results

    @classmethod
    def num_splits(
        cls,
        device: torch.device,
        dtype: torch.dtype,
        inner_fn: Callable[[List[Expr]], Any],
        axis: int,
        pointwise_ranges: List[Expr],
        scan_ranges: List[Expr],
        combine_fn: Callable[[Tuple[Any, ...], Tuple[Any, ...]], Tuple[Any, ...]],
        scan_numel: Expr,
    ):
        # TODO: custom splitting heuristic for scan
        def wrapper_fn(idx, reduction_idx):
            return inner_fn([*idx[:axis], *reduction_idx, *idx[axis:]])

        return Reduction.num_splits(
            device=device,
            dst_dtype=dtype,
            src_dtype=dtype,
            inner_fn=wrapper_fn,
            ranges=pointwise_ranges,
            reduction_ranges=scan_ranges,
            reduction_type="sum",
            reduction_numel=scan_numel,
        )


# This signifies a scan op that should go through TritonSplitScanKernel codegen on CUDA.
@dataclasses.dataclass
class SplitScan(Scan):
    pass


def is_storage_and_layout(x):
    try:
        as_storage_and_layout(x, freeze=False)
        return True
    except NotImplementedError:
        return False


def is_contiguous_storage_and_layout(x):
    try:
        buffer, layout = as_storage_and_layout(x, freeze=False)
        # pad the stride here so we will NOT claim a tensor as contiguous
        # if padding is going to happen.
        if layout.should_pad_strides():
            layout.pad_strides()
        return layout.is_contiguous()
    except NotImplementedError:
        return False


def as_storage_and_layout(
    x, freeze=True, want_contiguous=False, stride_order=None, allow_padding=False
):
    """
    Try to simplify x into a StorageBox and a Layout.

    allow_padding only affects how we apply stride_order. When allow_padding
    is True, we have the freedom to add padding when applying the stride_order.
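
    A minimal usage sketch (the variable names here are illustrative, not from
    this file):

        storage_box, layout = as_storage_and_layout(tensor_box, freeze=False)

    Raises NotImplementedError when x cannot be reduced to a StorageBox wrapping
    a Buffer (e.g. an unrealized Pointwise); is_storage_and_layout() above relies
    on catching that exception.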
1807 """ 1808 if isinstance(x, TensorBox): 1809 return as_storage_and_layout( 1810 x.data, 1811 freeze=freeze, 1812 want_contiguous=want_contiguous, 1813 stride_order=stride_order, 1814 allow_padding=allow_padding, 1815 ) 1816 if isinstance(x, StorageBox) and isinstance(x.data, Buffer): 1817 if freeze: 1818 if want_contiguous: 1819 x.data.freeze_layout() 1820 assert x.data.layout.is_contiguous() 1821 elif stride_order is not None: 1822 x.data.freeze_layout_with_stride_order( 1823 stride_order, allow_padding=allow_padding 1824 ) 1825 else: 1826 x.data.decide_layout() 1827 return x, x.data.layout 1828 if isinstance(x, ReinterpretView): 1829 # making the base of x contiguous or stride_ordered will not necessarily make 1830 # the ReinterpretView either, so don't pass along those arguments 1831 buffer, _ = as_storage_and_layout( 1832 x.data, 1833 freeze=freeze, 1834 ) 1835 return buffer, x.layout 1836 raise NotImplementedError 1837 1838 1839as_contiguous_storage_and_layout = functools.partial( 1840 as_storage_and_layout, want_contiguous=True 1841) 1842 1843 1844def is_stride_order_storage_and_layout(x, stride_order): 1845 try: 1846 buffer, layout = as_storage_and_layout(x, freeze=False) 1847 return layout.is_stride_ordered(stride_order) 1848 except NotImplementedError: 1849 return False 1850 1851 1852@dataclasses.dataclass 1853class BaseView(IRNode): 1854 data: IRNode 1855 1856 def get_unbacked_symbol_uses(self): 1857 return self.data.get_unbacked_symbol_uses() 1858 1859 def make_reindexer(self): 1860 raise NotImplementedError(f"make_reindexer NYI on {self}") 1861 1862 def make_indexer(self): 1863 inner = self.data.make_indexer() 1864 reindex = self.make_reindexer() 1865 1866 def indexer(idx): 1867 return inner(reindex(idx)) 1868 1869 return indexer 1870 1871 def make_loader(self): 1872 inner = self.data.make_loader() 1873 reindex = self.make_reindexer() 1874 1875 def loader(idx): 1876 return inner(reindex(idx)) 1877 1878 return loader 1879 1880 @property 1881 def dtype(self): 1882 return self.data.dtype 1883 1884 def get_layout(self): 1885 return self.data.get_layout() 1886 1887 def get_device(self): 1888 return self.data.get_device() 1889 1890 def get_origin_node(self): 1891 return None 1892 1893 def get_name(self): 1894 return self.data.get_name() 1895 1896 def get_pointwise_size(self): 1897 return self.get_size() 1898 1899 def mark_reuse(self, users): 1900 return self.data.mark_reuse(users) 1901 1902 def has_exceeded_max_reads(self): 1903 return self.data.has_exceeded_max_reads() 1904 1905 def realize(self): 1906 return self.data.realize() 1907 1908 def realize_hint(self): 1909 return self.data.realize_hint() 1910 1911 def get_storage_numel(self): 1912 return self.data.get_storage_numel() 1913 1914 def is_extern(self): 1915 return self.data.is_extern() # type: ignore[attr-defined] 1916 1917 def is_module_buffer(self): 1918 return self.data.is_module_buffer() # type: ignore[attr-defined] 1919 1920 def get_reads(self): 1921 with patch.object(FlexibleLayout, "allow_indexing", True): 1922 return extract_read_writes( 1923 self.make_loader(), 1924 self.get_size(), 1925 ).reads 1926 1927 def unwrap_view(self): 1928 x: IRNode = self 1929 while isinstance(x, BaseView): 1930 x = x.data 1931 return x 1932 1933 def constant_to_device(self, device): 1934 """Move this to a given device. 
Requires that all reads are to constants."""
        loader = self.make_loader()
        loader = patch.object(ConstantBuffer, "override_device", device)(loader)
        return Pointwise(device, self.get_dtype(), loader, self.get_size())


@dataclasses.dataclass
class ExpandView(BaseView):
    size: List[Expr]

    @staticmethod
    def _normalize_size(x, new_size):
        """Replace `-1` with correct sizes"""
        sizevars = V.graph.sizevars
        new_size = list(map(sympy.expand, new_size))
        old_size = x.get_size()
        old_size = [None] * (len(new_size) - len(old_size)) + list(old_size)
        assert len(new_size) == len(old_size)
        for i in range(len(new_size)):
            if new_size[i] == -1:
                assert old_size[i] is not None
                new_size[i] = old_size[i]
            elif old_size[i] is None or old_size[i] == 1:
                pass
            else:
                # Sanity check: Expect broadcast compatibility
                #
                # NB: new_size[i] == old_size[i] is expected to already be
                # guarded because the meta formula was expected to have taught
                # us this equality.
                assert (
                    sizevars.size_hint(new_size[i] - old_size[i], fallback=0) == 0
                ), f"Broadcast failed in ExpandView({x.get_size()}, {new_size}) on dimension {i}"
        return new_size

    @classmethod
    def create(cls, x, new_size):
        new_size = cls._normalize_size(x, new_size)

        if is_storage_and_layout(x):
            storage, old_layout = as_storage_and_layout(x)
            skip = len(new_size) - len(old_layout.size)
            assert skip >= 0
            new_stride = [sympy.Integer(0)] * skip
            for stride, size in zip(old_layout.stride, old_layout.size):
                new_stride.append(stride if size != 1 else sympy.Integer(0))
            new_layout = FixedLayout(
                old_layout.device,
                old_layout.dtype,
                list(new_size),
                new_stride,
                old_layout.offset,
            )
            return ReinterpretView(storage, new_layout)

        return ExpandView(x, new_size)

    def get_size(self):
        return self.size

    def make_reindexer(self):
        target = self.get_size()
        actual = self.data.get_size()
        skip = len(target) - len(actual)

        def reindex(index):
            index = list(index[skip:])
            assert len(index) == len(actual)
            for i in range(len(actual)):
                if actual[i] == 1:
                    # zero out broadcast dimension
                    index[i] = sympy.Integer(0)
            return index

        return reindex


@dataclasses.dataclass
class PermuteView(BaseView):
    dims: List[Expr]

    @classmethod
    def create(cls, x, dims):
        dims = cls._map_neg_dims(dims)
        assert set(dims) == set(range(len(dims)))

        if is_storage_and_layout(x):
            storage, old_layout = as_storage_and_layout(x)
            new_layout = FixedLayout(
                old_layout.device,
                old_layout.dtype,
                [old_layout.size[i] for i in dims],
                [old_layout.stride[i] for i in dims],
                old_layout.offset,
            )
            return ReinterpretView(storage, new_layout)

        return PermuteView(x, dims)

    @classmethod
    def _map_neg_dims(cls, dims):
        return [dim if dim >= 0 else len(dims) + dim for dim in dims]

    def get_size(self):
        assert set(self._map_neg_dims(self.dims)) == set(range(len(self.dims)))
        size = self.data.get_size()
        return [size[i] for i in self.dims]

    def make_reindexer(self):
        inv = {j: i for i, j in enumerate(self.dims)}
        inv = [inv[i] for i in range(len(self.dims))]  # type: ignore[index]
        assert set(inv) == set(range(len(self.dims)))

        def reindex(index):
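            # Illustrative (assumed dims, not from the original comments): for
            # dims = [2, 0, 1], inv == [1, 2, 0], so an index [i0, i1, i2] into the
            # permuted view maps to [i1, i2, i0] in self.data.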
2048 return [index[i] for i in inv] 2049 2050 return reindex 2051 2052 2053class SqueezeView(BaseView): 2054 @classmethod 2055 def create(cls, x, *, dim=None): 2056 if is_storage_and_layout(x): 2057 storage, old_layout = as_storage_and_layout(x) 2058 new_size = [] 2059 new_stride = [] 2060 if dim is not None: 2061 assert isinstance(dim, int), "expected integer dim argument" 2062 assert 0 <= dim and dim < len(old_layout.size) 2063 2064 for i, (size, stride) in enumerate(zip(old_layout.size, old_layout.stride)): 2065 if dim is None: 2066 if size != 1: 2067 new_size.append(size) 2068 new_stride.append(stride) 2069 else: 2070 if i != dim: 2071 new_size.append(size) 2072 new_stride.append(stride) 2073 else: 2074 assert size == 1, "expected squeezed size to be 1" 2075 2076 new_layout = FixedLayout( 2077 old_layout.device, 2078 old_layout.dtype, 2079 new_size, 2080 new_stride, 2081 old_layout.offset, 2082 ) 2083 return ReinterpretView(storage, new_layout) 2084 2085 if dim is None: 2086 # redirect to a generic view 2087 return View.create(x, [s for s in x.get_size() if s != 1]) 2088 else: 2089 assert x.get_size()[dim] == 1 2090 return View.create(x, [s for i, s in enumerate(x.get_size()) if i != dim]) 2091 2092 @staticmethod 2093 def squeezer(size: Tuple[sympy.Expr, ...]): 2094 new_size = [s for s in size if s != 1] 2095 not_one = [i for i, s in enumerate(size) if s != 1] 2096 length = len(size) 2097 2098 def reindex(index: List[sympy.Expr]) -> Tuple[sympy.Expr, ...]: 2099 assert len(index) == len(not_one), f"{index} {not_one}" 2100 new_index = [sympy.Integer(0)] * length 2101 for idx, s in zip(not_one, index): 2102 new_index[idx] = s 2103 return tuple(new_index) 2104 2105 return new_size, reindex 2106 2107 def __init__(self, data): 2108 raise AssertionError("use SqueezeView.create()") 2109 2110 2111@dataclasses.dataclass 2112class GenericView(BaseView): 2113 size: List[Expr] 2114 reindex: Callable[..., Any] 2115 2116 def make_reindexer(self): 2117 return self.reindex 2118 2119 def reindex_str(self): 2120 index_old = [ 2121 sympy_index_symbol_with_prefix(SymT.INDEX, n) for n in range(len(self.size)) 2122 ] 2123 index_new = list(self.reindex(index_old)) 2124 return f"lambda {', '.join(map(str, index_old))}: {index_new}" 2125 2126 def __str__(self): 2127 return self.str_helper( 2128 [self.data, f"size={self.size}", f"reindex={self.reindex_str()}"] 2129 ) 2130 2131 __repr__ = __str__ 2132 2133 @classmethod 2134 def create(cls, x, new_size, reindex): 2135 return cls(x, list(new_size), reindex) 2136 2137 def get_size(self): 2138 return self.size 2139 2140 2141@dataclasses.dataclass 2142class View(GenericView): 2143 @staticmethod 2144 def handle_negative_index(idx, size): 2145 idx = sympy.expand(idx) 2146 size = sympy.expand(size) 2147 evaluate_expr = V.graph.sizevars.shape_env.evaluate_expr 2148 if evaluate_expr(sympy.Lt(idx, 0)): 2149 idx = idx + size 2150 return idx 2151 2152 @classmethod 2153 def create(cls, x, new_size): 2154 assert isinstance(new_size, (tuple, list)) 2155 old_size, new_size = cls.resolve_negative_size(x.get_size(), new_size) 2156 2157 # Skip pointless views 2158 if V.graph.sizevars.statically_known_list_equals(old_size, new_size): 2159 return x 2160 2161 unbacked_symbols_in_sizes = False 2162 if ( 2163 len(free_unbacked_symbols(old_size)) > 0 2164 or len(free_unbacked_symbols(new_size)) > 0 2165 ): 2166 unbacked_symbols_in_sizes = True 2167 2168 if 0 in new_size: 2169 2170 def fake_reindex(index): 2171 return tuple([0] * len(old_size)) 2172 2173 return cls(x, list(new_size), 
fake_reindex) 2174 # TODO: a new class for FixedTransferLayout that output layout is constrained by input layout 2175 elif is_contiguous_storage_and_layout(x) or unbacked_symbols_in_sizes: 2176 if unbacked_symbols_in_sizes and (not is_contiguous_storage_and_layout(x)): 2177 # realize x; otherwise, the dynamic_reshape_indexer below will fail 2178 # due to the size_hint's inability to process unbacked SymInts 2179 x = ExternKernel.realize_input(x) 2180 2181 storage, old_layout = as_contiguous_storage_and_layout(x) 2182 new_layout = FixedLayout( 2183 old_layout.device, 2184 old_layout.dtype, 2185 new_size, 2186 FlexibleLayout.contiguous_strides(new_size), 2187 old_layout.offset, 2188 ) 2189 return ReinterpretView(storage, new_layout) 2190 2191 reindex = cls.dynamic_reshape_indexer(old_size, new_size) 2192 return cls(x, list(new_size), reindex) 2193 2194 @staticmethod 2195 def resolve_negative_size(old_size, new_size): 2196 new_size = [V.graph.sizevars.simplify(x) for x in new_size] 2197 old_size = [V.graph.sizevars.simplify(x) for x in old_size] 2198 2199 new_size = list(new_size) 2200 for i in range(len(new_size)): 2201 if new_size[i] == -1: 2202 new_size[i] = sympy.Integer(1) 2203 new_size[i] = CleanDiv(sympy_product(old_size), sympy_product(new_size)) 2204 break 2205 2206 V.graph.sizevars.guard_equals(sympy_product(old_size), sympy_product(new_size)) 2207 return old_size, new_size 2208 2209 @classmethod 2210 def dynamic_reshape_indexer(cls, old_size, new_size): 2211 try: 2212 reindex = cls._dynamic_reshape_indexer(old_size, new_size) 2213 except (AssertionError, IndexError): 2214 # optimistic algorithm failed, lets do a fallback 2215 flat = [sympy_product(old_size)] 2216 reindex1 = cls._dynamic_reshape_indexer(old_size, flat) 2217 reindex2 = cls._dynamic_reshape_indexer(flat, new_size) 2218 reindex = fuse_reindexing(reindex1, reindex2) 2219 return reindex 2220 2221 @staticmethod 2222 def _dynamic_reshape_indexer(old_size, new_size): 2223 """ 2224 Perform a reshape entirely by modifying indexing math 2225 """ 2226 size_hint = V.graph.sizevars.size_hint 2227 # TODO: These symbols may not escape, if they don't assert so and 2228 # treat them as temporary 2229 vars = [ 2230 sympy_index_symbol_with_prefix(SymT.VIEW, i) for i in range(len(new_size)) 2231 ] 2232 2233 stack_new = list(zip(vars, new_size)) 2234 stack_old = list(old_size) 2235 2236 view_expr = [] 2237 while stack_new and stack_old: 2238 size_old = stack_old.pop() 2239 var, size_new = stack_new.pop() 2240 if size_old == 1: 2241 view_expr.append(sympy.Integer(0)) 2242 stack_new.append((var, size_new)) # re-add 2243 elif size_new == 1: 2244 stack_old.append(size_old) # re-add 2245 elif size_hint(size_new) == size_hint(size_old): 2246 view_expr.append(var) 2247 V.graph.sizevars.guard_equals(size_new, size_old) 2248 elif size_hint(size_new) < size_hint(size_old): 2249 while size_hint(size_new) < size_hint(size_old): 2250 var2, size_new2 = stack_new.pop() 2251 var = var2 * size_new + var 2252 size_new = size_new * size_new2 2253 view_expr.append(var) 2254 V.graph.sizevars.guard_equals(size_new, size_old) 2255 elif size_hint(size_new) > size_hint(size_old): 2256 divisor = sympy.Integer(1) 2257 modulus = size_old 2258 view_expr.append(ModularIndexing(var, divisor, modulus)) 2259 divisor = divisor * modulus 2260 while size_hint(size_new) > size_hint(size_old): 2261 modulus = stack_old.pop() 2262 view_expr.append(ModularIndexing(var, divisor, modulus)) 2263 divisor = divisor * modulus 2264 size_old = size_old * modulus 2265 
V.graph.sizevars.guard_equals(size_new, size_old) 2266 else: 2267 raise AssertionError 2268 2269 while stack_old: 2270 size_old = stack_old.pop() 2271 V.graph.sizevars.guard_equals(size_old, 1) # type: ignore[arg-type] 2272 view_expr.append(sympy.Integer(0)) 2273 2274 while stack_new: 2275 var, size_new = stack_new.pop() 2276 V.graph.sizevars.guard_equals(size_new, 1) # type: ignore[arg-type] 2277 2278 view_expr.reverse() 2279 assert len(view_expr) == len(old_size) 2280 2281 def reindex(index): 2282 assert len(index) == len(vars), (len(index), len(vars)) 2283 replacements = dict(zip(vars, index)) 2284 return tuple(sympy_subs(x, replacements) for x in view_expr) # type: ignore[arg-type] 2285 2286 return reindex 2287 2288 2289@dataclasses.dataclass 2290class ReinterpretView(BaseView): 2291 """Pretend our storage has a different layout""" 2292 2293 layout: "Layout" 2294 2295 def __post_init__(self): 2296 super().__post_init__() 2297 if isinstance(self.data, BaseView): 2298 self.data = self.data.unwrap_view() 2299 2300 def __str__(self): 2301 return self.str_helper( 2302 [ 2303 self.data, 2304 self.layout, 2305 ] 2306 ) 2307 2308 __repr__ = __str__ 2309 2310 def get_name(self): 2311 return self.data.get_name() 2312 2313 def get_device(self): 2314 return self.layout.device 2315 2316 def get_origin_node(self): 2317 return None 2318 2319 @property 2320 def dtype(self): 2321 return self.layout.dtype 2322 2323 def get_size(self): 2324 return list(self.layout.size) 2325 2326 def get_stride(self): 2327 return list(self.layout.stride) 2328 2329 def make_loader(self): 2330 def loader(index): 2331 indexer = self.layout.make_indexer() 2332 return ops.load(self.get_name(), indexer(index)) 2333 2334 return loader 2335 2336 def make_indexer(self): 2337 return self.layout.make_indexer() 2338 2339 def get_layout(self): 2340 return self.layout 2341 2342 def freeze_layout(self): 2343 pass 2344 2345 def get_unbacked_symbol_uses(self) -> Set[sympy.Symbol]: 2346 return ( 2347 free_unbacked_symbols(self.layout.size) 2348 | free_unbacked_symbols(self.layout.stride) 2349 | free_unbacked_symbols(self.layout.offset) 2350 ) 2351 2352 def codegen_reference(self, writer=None): 2353 # reinterpret_tensor is similar to as_strided except: 2354 # - offset is added to the existing offset (rather than replacing it) 2355 # - view tracking is disabled similar to unsafe_view 2356 return V.graph.wrapper_code.codegen_reinterpret_view( 2357 self.data, 2358 self.layout.size, 2359 self.layout.stride, 2360 self.layout.offset, 2361 writer, 2362 ) 2363 2364 2365class SliceView(View): 2366 @classmethod 2367 def normalize_start_end(cls, x, dim, start, end): 2368 """ 2369 Normalize start and end such that both are in the range 2370 [0, x.get_size()[dim]] and start <= end. 
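
        Illustrative example (added; with a static dim_size = 10):
        (start=-3, end=None) normalizes to (7, 10), while (start=15, end=20)
        clamps to (10, 10).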
2371 """ 2372 sizevars = V.graph.sizevars 2373 dim_size = x.get_size()[dim] 2374 2375 if any(free_unbacked_symbols(x) for x in (start, end, dim_size)): 2376 2377 def clamp(x, lower, upper): 2378 return sympy.Min(sympy.Max(x, lower), upper) 2379 2380 else: 2381 2382 def clamp(x, lower, upper): 2383 return sizevars.evaluate_min(sizevars.evaluate_max(x, lower), upper) 2384 2385 def clamp_wrap(val, lower, upper, default): 2386 if val is None: 2387 return default 2388 val = cls.handle_negative_index(val, dim_size) 2389 return clamp(val, lower, upper) 2390 2391 start = clamp_wrap(start, 0, dim_size, 0) 2392 end = clamp_wrap(end, start, dim_size, dim_size) 2393 return start, end 2394 2395 @classmethod 2396 def create(cls, x, dim, start, end, step=1, clamp=True): 2397 step = sympy.expand(step) 2398 assert step > 0 2399 try: 2400 if start == 0 and end >= 2**63 - 1 and step == 1: 2401 return x 2402 except TypeError: 2403 pass 2404 2405 sizevars = V.graph.sizevars 2406 new_size = list(x.get_size()) 2407 2408 # NB: Ordinarily we default to clamping. 2409 # We only don't clamp for split_with_sizes. For split_with_sizes, sizes should be already valid 2410 # failing in this situation is ok, since invalid sizes could trigger silent errors. 2411 if clamp: 2412 start, end = cls.normalize_start_end(x, dim, start, end) 2413 2414 new_size[dim] = FloorDiv(end - start + (step - 1), step) 2415 2416 if is_storage_and_layout(x): 2417 # Fast path 2418 storage, old_layout = as_storage_and_layout(x) 2419 new_stride = list(old_layout.stride) 2420 new_stride[dim] = new_stride[dim] * step 2421 new_layout = FixedLayout( 2422 old_layout.device, 2423 old_layout.dtype, 2424 new_size, 2425 new_stride, 2426 old_layout.offset + old_layout.stride[dim] * start, 2427 ) 2428 return ReinterpretView(storage, new_layout) 2429 2430 def reindex(index): 2431 assert len(index) == len(new_size), f"wrong ndim {index} {new_size}" 2432 index = list(index) 2433 index[dim] = index[dim] * step + start 2434 return index 2435 2436 # redirect to a generic view 2437 return SliceView(x, size=new_size, reindex=reindex) 2438 2439 2440class BaseConstant(IRNode): 2441 dtype: torch.dtype 2442 device: torch.device 2443 2444 def get_size(self): 2445 return () 2446 2447 def get_device(self): 2448 return self.device 2449 2450 def get_origin_node(self): 2451 return None 2452 2453 def mark_reuse(self, users): 2454 pass 2455 2456 def has_exceeded_max_reads(self): 2457 return False 2458 2459 def get_reads(self): 2460 return () 2461 2462 def is_extern(self): 2463 return False 2464 2465 2466@dataclasses.dataclass 2467class Constant(BaseConstant): 2468 value: Any 2469 dtype: torch.dtype 2470 device: torch.device 2471 2472 def make_loader(self): 2473 def loader(index): 2474 return ops.constant(self.value, self.dtype) 2475 2476 return loader 2477 2478 def realize(self): 2479 pass 2480 2481 def constant_to_device(self, device): 2482 return Constant(self.value, self.dtype, device) 2483 2484 2485@dataclasses.dataclass 2486class IndexingConstant(BaseConstant): 2487 index: Any 2488 dtype: torch.dtype 2489 device: torch.device 2490 2491 def make_loader(self): 2492 def loader(index): 2493 return ops.index_expr(self.index, self.dtype) 2494 2495 return loader 2496 2497 def constant_to_device(self, device): 2498 return IndexingConstant(self.index, self.dtype, device) 2499 2500 2501def is_contiguous_strides_for_shape(stride, shape): 2502 return all( 2503 size == 1 or left == right 2504 for left, right, size in zip( 2505 stride, FlexibleLayout.contiguous_strides(shape), shape 
        )
    )


def get_align_for_dtype(dtype):
    """
    CUDA max memory transaction size is 128 bytes for a warp.
    We pick `128 // dtype.itemsize` as the alignment so the GPU can do
    coalesced memory access.
    """
    return 128 // dtype.itemsize


@dataclasses.dataclass
class Layout(IRNode):
    def __init__(
        self,
        device: torch.device,
        dtype: torch.dtype,
        size: List[Expr],
        stride: Optional[Sequence[Union[Expr, int]]],
        offset: Expr = Integer(0),
    ):
        assert stride is None or len(size) == len(
            stride
        ), f"size={size}, stride={stride}"
        self.device = device
        self.dtype = dtype
        assert all(isinstance(s, (Expr, int)) for s in size)
        self.size = size
        self._stride = stride
        self.offset = offset

    @property
    def stride(self):
        return self._stride

    def __str__(self):
        offset = ""
        if self.offset != 0:
            offset = f", offset={self.offset}"
        return (
            f"{type(self).__name__}('{self.device.type}', {self.dtype}, "
            f"size={self.size}, stride={self.stride}{offset})"
        )

    __repr__ = __str__

    def is_contiguous(self):
        return is_contiguous_strides_for_shape(self.stride, self.size)

    @staticmethod
    def is_channels_last_contiguous(shape, strides):
        ndim = len(shape)
        if ndim not in [4, 5] or shape[1] == 1:
            return False
        for left, right, size in zip(
            strides, make_channels_last_strides_for(shape), shape  # type: ignore[arg-type]
        ):
            if size != 1 and left != right:
                return False
        return True

    def is_transposed(self):
        for left, right, size in zip(
            self.stride,
            reversed(FlexibleLayout.contiguous_strides(self.size)),
            self.size,
        ):
            if size != 1 and left != right:
                return False
        return True

    def is_stride_ordered(self, order):
        assert len(self.stride) == len(order)

        # ignore dimensions of size 1; they don't affect layout
        non_1_indices = [
            i
            for i, dim in enumerate(self.size)
            if V.graph.sizevars.size_hint(dim, fallback=2) != 1
        ]

        stride = [self.stride[i] for i in non_1_indices]
        order = [order[i] for i in non_1_indices]

        def sorted_indices(arr):
            sorted_arr = sorted(arr)
            return [sorted_arr.index(element) for element in arr]

        # since we may have removed dimensions, need to re-sort & re-index order
        order = sorted_indices(order)

        # reorder the stride given order
        stride_ordered = [-1] * len(order)
        for i in range(len(order)):
            stride_ordered[order[i]] = V.graph.sizevars.size_hint(stride[i])
        # check if it is in ascending order
        for i in range(len(order) - 1):
            if stride_ordered[i] > stride_ordered[i + 1]:
                return False
        return True

    def is_channels_last_stride_ordered(self):
        # create the channels_last order (NCHW, NCDHW); the C dim comes first in the order.
        order = [0] + list(reversed(range(1, len(self.stride) - 1)))
        order = [len(order)] + order
        return self.is_stride_ordered(order)

    @staticmethod
    def _pad_strides(in_strides, size, dtype):
        """
        The padding does not change the stride order, but it makes sure all strides
        larger than the threshold are a multiple of the alignment.
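
        Illustrative example (added; assuming float32, so align = 128 // 4 = 32):
        for size = [2048, 2000] with contiguous in_strides = [2000, 1], the
        smallest stride stays 1, while the next stride 1 * 2000 = 2000 exceeds
        the 1024 threshold and 2000 % 32 != 0, so it is padded up to
        ceildiv(2000, 32) * 32 = 2016, giving padded strides [2016, 1].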
        """
        align = get_align_for_dtype(dtype)
        if len(in_strides) == 0:
            return in_strides

        if not config.pad_channels_last and Layout.is_channels_last_contiguous(
            size, in_strides
        ):
            return in_strides

        current_fx_node = V.get_current_node()
        if hasattr(current_fx_node, "meta") and current_fx_node.meta.get(
            "dislike_padding", False
        ):
            return in_strides

        # get_stride_order does not work with dynamic shape. Also we cannot
        # statically decide if a padding is needed or how much padding we should
        # do for dynamic shape.
        #
        # Skip padding the strides for dynamic shape for now.
        if not all(
            isinstance(s, (int, sympy.Integer))
            for s in itertools.chain(in_strides, size)
        ):
            return in_strides

        stride_order = get_stride_order(in_strides)
        fill_order = stride_order2fill_order(stride_order)

        new_strides = [0 for _ in range(len(in_strides))]
        # since we pad when the layout is flexible, we can decide the
        # smallest stride to be 1.
        new_strides[fill_order[0]] = 1

        # Don't align a too-small stride, since that causes too much memory
        # increase. Padding a too-small stride may also cause a perf loss: we
        # may end up with many tiny data blocks with gaps in between, which
        # causes less coalesced GPU memory access!
        #
        # Initially we picked 320 as the threshold since, for alignment=16,
        # that results in at most 5% memory cost.
        #
        # But later on we raised the threshold to 1024 to avoid interfering
        # with persistent reductions. Let's say an inner reduction has a row
        # size of 513. Inductor will generate persistent reduction code.
        # If we do padding, the strides are not contiguous any more. Inductor
        # uses a much smaller threshold for persistent reduction in this case
        # and generates potentially worse non-persistent reduction code.
        #
        # This change turns HF AllenaiLongformerBase amp training from a loss of 1.09x to a win of 1.05x.
        # (baseline: 71.09ms, padding w/o this change: 77.38ms, padding with this change: 67.77ms)
        align_stride_threshold = 1024
        padded = False
        for rank, idx in enumerate(fill_order[1:], start=1):
            prev_idx = fill_order[rank - 1]
            stride = new_strides[prev_idx] * size[prev_idx]

            if stride > align_stride_threshold and stride % align != 0:
                stride = ceildiv(stride, align) * align
                padded = True
            new_strides[idx] = stride

        if not padded:
            # Consider a tensor with shape [256, 1, 5, 5]
            # Avoid strides like [25, 5, 5, 1] being padded to equivalent strides
            # [25, 25, 5, 1].
2686 return in_strides 2687 2688 metrics.num_comprehensive_padding += 1 2689 return new_strides 2690 2691 def pad_strides(self): 2692 assert isinstance(self, FlexibleLayout) 2693 assert self._stride is not None 2694 self._stride = self._pad_strides(self._stride, self.size, self.dtype) 2695 2696 def should_pad_strides(self): 2697 return config.comprehensive_padding and isinstance(self, FlexibleLayout) 2698 2699 def as_fixed(self): 2700 if isinstance(self, FixedLayout): 2701 return self 2702 2703 if self.should_pad_strides(): 2704 self.pad_strides() 2705 return FixedLayout( 2706 self.device, 2707 self.dtype, 2708 self.size, 2709 self.stride, 2710 self.offset, 2711 ) 2712 2713 def make_indexer(self): 2714 assert ( 2715 FlexibleLayout.allow_indexing 2716 ), f"convert {type(self).__name__} to FixedLayout first" 2717 return self.as_fixed().make_indexer() 2718 2719 def __eq__(self, other) -> bool: 2720 return ( 2721 self.device == other.device 2722 and self.dtype == other.dtype 2723 and self.size == other.size 2724 and self.stride == other.stride 2725 and self.offset == other.offset 2726 ) 2727 2728 def storage_size(self) -> sympy.Expr: 2729 return compute_required_storage_length(self.size, self.stride, self.offset) # type: ignore[arg-type, return-value] 2730 2731 2732class FixedLayout(Layout): 2733 """A Tensor layout we cannot change""" 2734 2735 def __init__( 2736 self, 2737 device: torch.device, 2738 dtype: torch.dtype, 2739 size: Union[List[Expr], List[int]], 2740 stride: Optional[Sequence[Union[Expr, int]]] = None, 2741 offset: Union[Expr, int] = Integer(0), 2742 ): 2743 if stride is None: 2744 stride = FlexibleLayout.contiguous_strides(size) 2745 super().__init__( 2746 device, 2747 dtype, 2748 size, # type: ignore[arg-type] 2749 stride, 2750 offset, # type: ignore[arg-type] 2751 ) 2752 2753 def make_indexer(self): 2754 """A closure containing math to read a given element""" 2755 2756 def indexer(index): 2757 assert len(index) == len(self.stride) 2758 assert len(index) == len(self.size) 2759 result = self.offset 2760 for idx, stride, sz in zip(index, self.stride, self.size): 2761 if sz != 1: 2762 result = result + idx * stride 2763 return result 2764 2765 return indexer 2766 2767 2768class FlexibleLayout(Layout): 2769 """A Tensor layout we are allowed to change""" 2770 2771 allow_indexing = False 2772 2773 # WARNING! This doesn't handle zero size tensors correctly 2774 @staticmethod 2775 def contiguous_strides(sizes): 2776 if len(sizes) == 0: 2777 return [] 2778 reversed_strides = [sympy.Integer(1)] 2779 for size in reversed(sizes[1:]): 2780 reversed_strides.append(size * reversed_strides[-1]) 2781 return list(reversed(reversed_strides)) 2782 2783 @staticmethod 2784 def fill_ordered(sizes, order): 2785 """ 2786 Create a stride based on the order the dimensions should be filled in. 2787 2788 In this format, channels last would be: 2789 [1, 3, 2, 0] 2790 """ 2791 assert set(range(len(sizes))) == set(order) 2792 next_stride = sympy.Integer(1) 2793 strides = [None] * len(order) 2794 2795 for i in order: 2796 strides[i] = next_stride 2797 next_stride = next_stride * sizes[i] 2798 return strides 2799 2800 @staticmethod 2801 def stride_ordered(sizes, order): 2802 """ 2803 Create a stride based on the sorted order of a permuted range. 
2804 2805 In this format, channels last would be: 2806 [3, 0, 2, 1] 2807 """ 2808 assert set(range(len(sizes))) == set(order) 2809 fill_order = stride_order2fill_order(order) 2810 return FlexibleLayout.fill_ordered(sizes, fill_order) 2811 2812 @staticmethod 2813 def stride_ordered_for_memory_format(sizes, memory_format): 2814 """ 2815 Create a stride based on a memory format. 2816 2817 Memory format is translasted into a stride order, 2818 so channels_last is the same as: 2819 FlexibleLayout.stride_ordered(sizes, [3, 0, 2, 1]) 2820 2821 This interface does not support memory_format `torch.preserve_format` 2822 which should be used to deduce a format from another source 2823 """ 2824 if memory_format == torch.channels_last: 2825 return FlexibleLayout.stride_ordered(sizes, NHWC_STRIDE_ORDER) 2826 elif memory_format == torch.channels_last_3d: 2827 return FlexibleLayout.stride_ordered(sizes, NHWDC_STRIDE_ORDER) 2828 elif memory_format == torch.contiguous_format: 2829 return FlexibleLayout.contiguous_strides(sizes) 2830 else: 2831 log.debug( 2832 "stride_ordered_for_memory_format, unsuppored memory_format: %s", 2833 memory_format, 2834 ) 2835 raise NotImplementedError 2836 2837 @staticmethod 2838 def same_ordered(sizes, stride): 2839 """ 2840 Create a stride that has the same stride order as given stride 2841 2842 For example, if given stride is [1000, 1, 100, 10], 2843 the fill order should be [1, 3, 2, 0] 2844 """ 2845 assert len(sizes) == len(stride) 2846 stride = [V.graph.sizevars.size_hint(x) for x in stride] 2847 fill_order = sorted(range(len(stride)), key=stride.__getitem__) 2848 return FlexibleLayout.fill_ordered(sizes, fill_order) 2849 2850 def as_stride_order(self, order, allow_padding=False): 2851 new_stride = self.stride_ordered(self.size, order) 2852 if self.should_pad_strides() and allow_padding: 2853 new_stride = self._pad_strides(new_stride, self.size, self.dtype) 2854 2855 return FixedLayout( 2856 self.device, 2857 self.dtype, 2858 self.size, 2859 new_stride, 2860 self.offset, 2861 ) 2862 2863 def as_fill_order(self, order): 2864 new_stride = self.fill_ordered(self.size, order) 2865 if self.should_pad_strides(): 2866 new_stride = self._pad_strides(new_stride, self.size, self.dtype) 2867 return FixedLayout( 2868 self.device, 2869 self.dtype, 2870 self.size, 2871 new_stride, 2872 self.offset, 2873 ) 2874 2875 def as_same_order(self, stride): 2876 new_stride = self.same_ordered(self.size, stride) 2877 if self.should_pad_strides(): 2878 new_stride = self._pad_strides(new_stride, self.size, self.dtype) 2879 return FixedLayout( 2880 self.device, 2881 self.dtype, 2882 self.size, 2883 new_stride, 2884 self.offset, 2885 ) 2886 2887 def __init__(self, device, dtype, size, stride_order=None): 2888 if stride_order: 2889 strides = FlexibleLayout.fill_ordered(size, stride_order) 2890 else: 2891 strides = FlexibleLayout.contiguous_strides(size) 2892 super().__init__(device, dtype, size, strides) 2893 2894 2895class NonOwningLayout(Layout): 2896 """Is a view into the storage of another tensor""" 2897 2898 def __init__(self, view: Union[BaseView, "TensorBox"]): 2899 layout = view.get_layout() 2900 super().__init__( 2901 layout.device, 2902 layout.dtype, 2903 layout.size, 2904 layout.stride, 2905 ) 2906 self.view = view 2907 2908 def make_indexer(self): 2909 return self.as_fixed().make_indexer() 2910 2911 def maybe_guard_aligned(self): 2912 offset = self.view.get_layout().offset 2913 if offset == 0: 2914 return True 2915 from .compile_fx import ALIGNMENT 2916 2917 return 
V.graph.sizevars.statically_known_multiple_of(offset, ALIGNMENT) # type: ignore[arg-type] 2918 2919 2920class NoneLayout(IRNode): 2921 # This is janky, I figured out what fields to populate by just running 2922 # the model I was interested in and adding properties/methods as needed. 2923 # This doesn't inherit from Layout because Layout assumes you have stuff 2924 # like sizes, but I don't really have anything here. 2925 # 2926 # If you have an ir.Node with NoneLayout, you probably need to setup 2927 # dependencies manually in scheduler 2928 2929 def __init__(self, device): 2930 self.device = device 2931 self.size = [0] 2932 self.stride = [0] 2933 2934 def storage_size(self): 2935 return 0 2936 2937 def as_fixed(self): 2938 return self 2939 2940 2941class MutationLayoutSHOULDREMOVE(Layout): 2942 def __init__(self, target: IRNode): 2943 super().__init__( 2944 target.get_device(), 2945 target.get_dtype(), 2946 target.get_size(), 2947 None, 2948 ) 2949 self.target = target 2950 name = self.get_buffer().get_name() 2951 V.graph.mark_buffer_mutated(name) 2952 2953 @Layout.stride.getter # type: ignore[attr-defined] 2954 def stride(self): 2955 return self.real_layout().stride 2956 2957 def storage_size(self) -> sympy.Expr: 2958 return self.real_layout().storage_size() 2959 2960 def get_buffer(self) -> "Buffer": 2961 def unwrap_views(target): 2962 if isinstance(target, MutationLayoutSHOULDREMOVE): 2963 return unwrap_views(target.target) 2964 if isinstance(target, BaseView): 2965 return unwrap_views(target.unwrap_view()) 2966 if isinstance(target, MutableBox): 2967 return unwrap_views(target.data) 2968 return target 2969 2970 result = unwrap_views(self.target) 2971 assert isinstance( 2972 result, Buffer 2973 ), "MutationLayoutSHOULDREMOVE must refer to a buffer" 2974 return result 2975 2976 def real_layout(self): 2977 return self.get_buffer().layout 2978 2979 @classmethod 2980 def realize_into(cls, src, dst, unsafe_alias=False): 2981 dst.realize() 2982 # NOTE: We must realize users of `dst` before we realize `src`, since 2983 # realization order determines scheduling order. Otherwise, src's 2984 # mutation would be scheduled before the existing users of dst! 2985 V.graph.mark_buffer_mutated(dst.get_name()) 2986 2987 if isinstance(src, TensorBox): 2988 src = src.data 2989 2990 # We copy the contents of src into dst. In most cases this should 2991 # be fused into a single kernel by the scheduler. 2992 # NOTE: We cannot change src's layout to mutate dst directly as this 2993 # would alias src to dst, which is not correct as further mutations to 2994 # dst would effect users of src. However if there are no more users of 2995 # dst, we can alias src to dst. 2996 src.realize_hint() 2997 2998 if not unsafe_alias: 2999 src = Pointwise.create( 3000 device=src.get_device(), 3001 dtype=src.get_dtype(), 3002 inner_fn=src.make_loader(), 3003 ranges=[ 3004 V.graph.sizevars.guard_equals(a, b) 3005 for a, b in zip(src.get_size(), dst.get_size()) 3006 ], 3007 ).data 3008 3009 src.realize() 3010 assert isinstance(src.data.layout, FlexibleLayout) 3011 src.data.layout = MutationLayoutSHOULDREMOVE(dst) 3012 return src.data 3013 3014 def as_fixed(self): 3015 return self 3016 3017 def make_indexer(self): 3018 return self.target.make_indexer() 3019 3020 3021@dataclasses.dataclass 3022class Buffer(IRNode): 3023 # Name is sometimes None; e.g., ForceInPlace, where there isn't 3024 # a meaningful name 3025 name: Optional[str] 3026 layout: Layout 3027 3028 # Multi-output buffers will define 'outputs: List[Buffer]'. 
Confusingly, 3029 # MultiOutput does NOT define this! 3030 3031 def __post_init__(self): 3032 super().__post_init__() 3033 self.origin_node = None 3034 3035 def make_indexer(self): 3036 return self.layout.make_indexer() 3037 3038 def get_name(self) -> str: 3039 assert self.name, self 3040 return self.name 3041 3042 def get_device(self): 3043 return self.layout.device 3044 3045 def get_origin_node(self): 3046 return self.origin_node 3047 3048 @property 3049 def dtype(self): 3050 return getattr(self.layout, "dtype", None) 3051 3052 def get_size(self): 3053 return list(self.layout.size) 3054 3055 def get_stride(self): 3056 return list(self.layout.stride) 3057 3058 def get_offset(self): 3059 return self.layout.offset 3060 3061 def get_layout(self): 3062 return self.layout 3063 3064 def get_storage_numel(self): 3065 return self.get_numel() 3066 3067 def is_extern(self): 3068 return False 3069 3070 def freeze_layout(self): 3071 if not isinstance(self.layout, (MultiOutputLayout, NonOwningLayout)): 3072 self.layout = self.layout.as_fixed() 3073 3074 def freeze_layout_with_stride_order(self, order, allow_padding=False): 3075 assert isinstance(self.layout, FlexibleLayout) 3076 self.layout = self.layout.as_stride_order(order, allow_padding=allow_padding) 3077 3078 def freeze_layout_with_fill_order(self, order): 3079 assert isinstance(self.layout, FlexibleLayout) 3080 self.layout = self.layout.as_fill_order(order) 3081 3082 def freeze_layout_with_same_order(self, stride): 3083 assert isinstance(self.layout, FlexibleLayout) 3084 self.layout = self.layout.as_same_order(stride) 3085 3086 def is_zero_elements(self): 3087 return V.graph.sizevars.is_expr_static_and_true(sympy.Eq(self.get_numel(), 0)) # type: ignore[arg-type] 3088 3089 def make_loader(self): 3090 # Loading from a zero-element buffer is a no-op 3091 if self.is_zero_elements(): 3092 return partial(nop_loader_fn, dtype=self.get_dtype()) 3093 3094 def loader(index): 3095 indexer = self.layout.make_indexer() 3096 return ops.load(self.name, indexer(index)) 3097 3098 return loader 3099 3100 def is_no_op(self): 3101 return False 3102 3103 def codegen_reference(self, writer=None): 3104 return self.get_name() 3105 3106 def decide_layout(self): 3107 pass 3108 3109 def get_inputs_that_alias_output(self): 3110 if isinstance(self.layout, NonOwningLayout): 3111 return [self.layout.view.get_name()] 3112 return () 3113 3114 def get_mutation_names(self): 3115 if isinstance(self.layout, MutationLayoutSHOULDREMOVE): 3116 return [self.layout.target.get_name()] 3117 return () 3118 3119 def get_read_writes(self): 3120 with patch.object(FlexibleLayout, "allow_indexing", True): 3121 return extract_read_writes( 3122 self.make_loader(), 3123 self.get_size(), 3124 ) 3125 3126 def get_reads(self): 3127 return self.get_read_writes().reads 3128 3129 def get_unbacked_symbol_defs(self) -> Set[sympy.Symbol]: 3130 return set() 3131 3132 def get_unbacked_symbol_uses(self) -> Set[sympy.Symbol]: 3133 """ 3134 Returns the unbacked symbols which are required to be in scope in 3135 order to successfully perform codegen for this buffer. For example, 3136 a buffer that corresponds to an extern kernel call that takes i0 as 3137 an argument would return {i0} here. This is used to generate necessary 3138 dependencies that ensure we actually bind i0 in codegen before you 3139 try to use it. 
3140 3141 Note that this is NOT transitive; in particular, if this buffer takes 3142 in as input another buffer with dynamic shape (e.g., (i0,)), we will 3143 not report it here, because you will already have a dependency 3144 on that buffer, which will eventually have a dependency on i0 if 3145 necessary. 3146 """ 3147 return set() 3148 3149 def realize(self): 3150 pass 3151 3152 def get_workspace_size(self): 3153 """ 3154 Gets extra global memory size needed by this buffer. 3155 Some algorithms (e.g. group gemm) may require extra global memory in the generated code. 3156 """ 3157 return 0 3158 3159 def should_allocate(self): 3160 # Returns False by default. 3161 return False 3162 3163 3164class InputBuffer(Buffer): 3165 pass 3166 3167 3168class ConstantBuffer(InputBuffer): 3169 override_device: Optional[torch.device] = None 3170 3171 def make_loader(self): 3172 def loader(index): 3173 indexer = self.layout.make_indexer() 3174 return ops.load( 3175 V.graph.constant_name(self.get_name(), self.override_device), 3176 indexer(index), 3177 ) 3178 3179 return loader 3180 3181 def constant_to_device(self, device): 3182 return ConstantBuffer( 3183 V.graph.constant_name(self.get_name(), device), self.layout 3184 ) 3185 3186 3187class NoneAsConstantBuffer(IRNode): 3188 def get_unbacked_symbol_uses(self) -> Set[sympy.Symbol]: 3189 return set() 3190 3191 def codegen_reference(self, writer=None): 3192 return V.graph.wrapper_code.none_str 3193 3194 3195class ShapeAsConstantBuffer(IRNode): 3196 def __init__(self, shape): 3197 super().__init__() 3198 self.shape = shape 3199 3200 def get_unbacked_symbol_uses(self) -> Set[sympy.Symbol]: 3201 return free_unbacked_symbols(self.shape) 3202 3203 def codegen_reference(self, writer=None): 3204 return V.graph.wrapper_code.expr_printer(V.graph.sizevars.simplify(self.shape)) 3205 3206 3207@dataclasses.dataclass 3208class ComputedBuffer(Buffer): 3209 data: Loops 3210 3211 def get_computed_buffer_name(self): 3212 """ 3213 Returns self.name if it exists, otherwise returns the name of the data node if that exists. 3214 If neither exist, returns None. 3215 """ 3216 if self.name is not None: 3217 return self.name 3218 if hasattr(self.data, "name"): 3219 return self.data.name 3220 return None 3221 3222 @cache_on_self 3223 def num_reads(self): 3224 return len(self.get_read_writes().reads) 3225 3226 def get_read_writes(self): 3227 with patch.object(FlexibleLayout, "allow_indexing", True): 3228 if self.data.get_reduction_type(): 3229 return extract_read_writes( 3230 self.get_store_function(), 3231 self.data.get_pointwise_size(), 3232 self.data.get_reduction_size(), 3233 ) 3234 else: 3235 return extract_read_writes( 3236 self.get_store_function(), 3237 self.data.get_size(), 3238 ) 3239 3240 def get_unbacked_symbol_uses(self) -> Set[sympy.Symbol]: 3241 # Ordinarily, we'd like to just peek at the arguments list, 3242 # but ComputedBuffers have no argument list. 3243 # 3244 # Morally, this logic needs to be synchronized with the 3245 # KernelArgs.size calls, which are responsible for making symbols make 3246 # there way as kernel arguments (and it is precisely passing in one of 3247 # those symbols that establishes a dependency). However, we haven't 3248 # started codegen yet so we can't directly reuse that logic. 3249 # 3250 # For now, I'm just yoloing with the size of the buffer. Not sure if 3251 # it is enough. 3252 # 3253 # One thing you might wonder is if this is enough for a ComputedBuffer 3254 # denoting a reduction over i0. 
Empirically, it is enough, but for an 3255 # unusual reason: we only need accurate dependencies for item() call, 3256 # but it's impossible to end up with a reduction over i0 from an 3257 # item() call without a regular non-reduction buffer first. 3258 return ( 3259 free_unbacked_symbols(self.get_size()) 3260 | free_unbacked_symbols(self.get_stride()) 3261 | free_unbacked_symbols(self.get_offset()) 3262 | self.data.get_unbacked_symbol_uses() 3263 ) 3264 3265 def make_loader(self): 3266 # Inline constants and index_expressions 3267 if ( 3268 hasattr(self.data, "make_loader") 3269 and self.name not in V.graph.mutated_buffers 3270 and self.num_reads() == 0 3271 ): 3272 # can be inlined 3273 return self.data.make_loader() 3274 return super().make_loader() 3275 3276 def get_store_function(self): 3277 indexer = self.layout.as_fixed().make_indexer() 3278 if isinstance(self.data, (Reduction, Scan)): 3279 return partial(self.data.store_reduction, self.name, indexer) 3280 else: 3281 assert isinstance(self.data, Pointwise) 3282 return partial(self.data.store_output, self.name, indexer) 3283 3284 def get_fill_order(self): 3285 """ 3286 If our layout is still flexible, try to determine the stride order based on stride orders of reads. 3287 3288 TODO(jansel): A better algorithm here would look at downstream consumers of this 3289 value and try to do global graph-level layout optimization. 3290 This is also something just begging to be autotuned. 3291 """ 3292 if isinstance(self.layout, FlexibleLayout): 3293 (index_vars, reduction_vars), _ = dependencies.index_vars_squeeze( 3294 self.data.get_pointwise_size(), self.data.get_reduction_size() 3295 ) 3296 reads = self.get_read_writes().reads 3297 reads_bufs = [ 3298 V.graph.name_to_buffer[r.name] 3299 if r.name in V.graph.name_to_buffer.keys() 3300 else None 3301 for r in reads 3302 ] 3303 # only consider reads to buffer of same size 3304 # ignore StarDeps because they don't contribute stride information 3305 assert all( 3306 isinstance(r, (dependencies.StarDep, dependencies.MemoryDep)) 3307 for r in reads 3308 ) 3309 reads = [ 3310 sympy_subs( 3311 r.index, {v: sympy.Integer(0) for v in reduction_vars if v != 0} 3312 ) 3313 for r in reads 3314 if isinstance(r, dependencies.MemoryDep) 3315 ] 3316 3317 if reads: 3318 if isinstance(self.data, Scan): 3319 indices = self.data.reindex(index_vars, reduction_vars) 3320 else: 3321 indices = index_vars 3322 stride_lengths = [ 3323 V.graph.sizevars.stride_hints(expr, indices) for expr in reads # type: ignore[arg-type] 3324 ] 3325 from .scheduler import pick_loop_order 3326 3327 return pick_loop_order(stride_lengths, self.get_size()) 3328 3329 return None 3330 3331 def decide_layout(self): 3332 if isinstance(self.layout, FlexibleLayout): 3333 order = self.get_fill_order() 3334 if order: 3335 self.freeze_layout_with_fill_order(order) 3336 else: 3337 self.freeze_layout() 3338 3339 @cache_on_self 3340 def get_default_sizes_body(self): 3341 args, var_ranges = dependencies.index_vars_squeeze( 3342 self.data.get_pointwise_size(), self.data.get_reduction_size(), prefix="q" 3343 ) 3344 with patch.object(ConstantBuffer, "override_device", self.get_device()): 3345 body = LoopBody( 3346 self.get_store_function(), 3347 (args if self.get_reduction_type() else args[:1]), 3348 var_ranges, 3349 ) 3350 index_vars = [] 3351 reduce_vars: List[Any] = [] 3352 index_size = [] 3353 reduce_size = [] 3354 for v, s in var_ranges.items(): 3355 if v in args[0]: 3356 assert not reduce_vars 3357 index_vars.append(v) 3358 index_size.append(s) 3359 
else: 3360 assert v in args[1] 3361 reduce_vars.append(v) 3362 reduce_size.append(s) 3363 return (index_size, reduce_size), body, (index_vars, reduce_vars) 3364 3365 def simplify_and_reorder( 3366 self, 3367 extra_indexing_constraints: Optional[Tuple[Dict[Any, Any], List[Any]]] = None, 3368 ): 3369 """ 3370 This is a main place where we do loop transformations in a 3371 backend-agnostic way. 3372 3373 Here we: 3374 1) Remove any 1 dimensions 3375 2) Fuse contiguous dimensions together 3376 3) Reorder dimensions based on stride orders 3377 3378 Optional argument extra_indexing_constraints can be used to append additional 3379 indexing expressions to existing ones derived from buffer's body. This can be useful 3380 to fuse scheduler nodes with compatible ranges, e.g. (s0*s1*...,) and (s0, s1, s2, ...) 3381 on CPU by preventing indexing simplifications and obtaining index/reduce ranges for 3382 the scheduler node compatible with other nodes. 3383 """ 3384 ( 3385 (index_size, reduce_size), 3386 body, 3387 (index_vars, reduce_vars), 3388 ) = self.get_default_sizes_body() 3389 3390 index_formulas = [*body.indexing_exprs.values()] 3391 if extra_indexing_constraints is not None: 3392 assert ( 3393 isinstance(extra_indexing_constraints, tuple) 3394 and len(extra_indexing_constraints) == 2 3395 ) 3396 extra_indexing_ranges, extra_indexing_expr = extra_indexing_constraints 3397 assert isinstance(extra_indexing_ranges, dict) 3398 assert isinstance(extra_indexing_expr, list) 3399 assert all(isinstance(f, Expr) for f in extra_indexing_expr) 3400 3401 expected_var_ranges = body.var_ranges 3402 assert expected_var_ranges == extra_indexing_ranges, ( 3403 expected_var_ranges, 3404 extra_indexing_ranges, 3405 ) 3406 # remove already existing expressions 3407 extra_indexing_expr = [ 3408 e for e in extra_indexing_expr if e not in index_formulas 3409 ] 3410 index_formulas += extra_indexing_expr 3411 3412 reads_bufs = [ 3413 V.graph.name_to_buffer[reads_name] 3414 if reads_name in V.graph.name_to_buffer.keys() 3415 else None 3416 for reads_name in body.reads_name2expr.keys() 3417 ] 3418 memory_addrs = [ 3419 *body.reads_name2expr.values(), 3420 *body.writes_name2expr.values(), 3421 ] 3422 3423 def simplify_and_reorder(x_vars, support_vars, sizes): 3424 sizes, reindex0, reindex1 = self._apply_loop_reordering( 3425 x_vars, support_vars, sizes, memory_addrs 3426 ) 3427 # for NHWC: reindex0([0,1,2,3]) = [0,2,3,1], reindex1([0,1,2,3]) = [0,3,2,1] 3428 x_vars = reindex0(x_vars) 3429 sizes, reindex2, prune = V.graph.sizevars._simplify_loops( 3430 x_vars, 3431 sizes, 3432 index_prevent_reordering(index_formulas, x_vars, sizes), 3433 ) 3434 x_vars = prune(x_vars) 3435 # sizes, reindex1, prune = _simplify_loops(x_vars, sizes, index_formulas) 3436 # x_vars = prune(x_vars) 3437 # sizes, reindex2 = self._apply_loop_reordering(x_vars, sizes, memory_addrs) 3438 reindex = fuse_reindexing(reindex1, reindex2) 3439 return sizes, reindex, reindex1 3440 3441 support_vars = index_vars + reduce_vars 3442 iter_ranges, iter_reindex, _ = simplify_and_reorder( 3443 index_vars, 3444 support_vars, 3445 index_size, 3446 ) 3447 reduce_ranges, reduce_reindex, _ = simplify_and_reorder( 3448 reduce_vars, support_vars, reduce_size 3449 ) 3450 3451 # retrace the loop body with simplification and reordering applied 3452 (iter_vars, reduce_vars), var_ranges = dependencies.index_vars_no_squeeze( 3453 iter_ranges, reduce_ranges, prefix="z" 3454 ) 3455 body = LoopBody( 3456 body, [iter_reindex(iter_vars), reduce_reindex(reduce_vars)], var_ranges 3457 
) 3458 return (iter_ranges, reduce_ranges), body 3459 3460 @staticmethod 3461 def _apply_loop_reordering( 3462 index_vars, 3463 support_vars, 3464 sizes, 3465 memory_addrs, 3466 priority_idx=None, 3467 ): 3468 """ 3469 Shuffle the order of loops around to hopefully improve performance. 3470 """ 3471 from .scheduler import pick_loop_order 3472 3473 if priority_idx is None: 3474 priority_idx = [] 3475 3476 try: 3477 strides = [ 3478 V.graph.sizevars.stride_hints(expr, index_vars, support_vars) 3479 for expr in memory_addrs 3480 ] 3481 assert len(strides) == len(memory_addrs) and len(strides[0]) == len( 3482 index_vars 3483 ) 3484 order = list(reversed(pick_loop_order(strides, sizes, priority_idx))) 3485 except Exception: 3486 if config.debug: 3487 log.warning( 3488 "Did not simplify complex index:\n%s\n%s", 3489 dict(zip(index_vars, sizes)), 3490 memory_addrs, 3491 ) 3492 order = list(range(len(sizes))) 3493 sizes = [sizes[i] for i in order] 3494 return sizes, same_reorder(order), inverse_reorder(order) 3495 3496 def get_reduction_size(self): 3497 return self.data.get_reduction_size() 3498 3499 def get_reduction_type(self): 3500 return self.data.get_reduction_type() 3501 3502 def is_no_op(self): 3503 return self.data.is_zero_elements() 3504 3505 def should_allocate(self): 3506 return True 3507 3508 def constant_to_device(self, device): 3509 """Move this to a given device. Requires that all reads are to constants.""" 3510 return self.data.constant_to_device(device) 3511 3512 3513class TemplateBuffer(Buffer): 3514 """ 3515 Represents a Triton (in the future other type) of template operator 3516 that we can fuse an epilogue onto. 3517 """ 3518 3519 def __init__(self, layout, inputs, make_kernel_render): 3520 super().__init__(name=None, layout=layout) 3521 self.inputs = InputsKernel.unwrap_storage(inputs) 3522 self.make_kernel_render = make_kernel_render 3523 self.name = V.graph.register_buffer(self) 3524 3525 def get_read_writes(self): 3526 return self.normalized_read_writes() 3527 3528 def normalized_read_writes(self): 3529 name = self.get_name() 3530 indexer = self.layout.make_indexer() 3531 3532 def dummy(index, rindex): 3533 assert len(rindex) == 0 3534 return ops.store(name, indexer(index), "fake") 3535 3536 deps = dependencies.extract_read_writes( 3537 dummy, self.get_size(), (), normalize=True 3538 ) 3539 deps.reads = {dependencies.StarDep(x.get_name()) for x in self.inputs} 3540 return deps 3541 3542 def get_reduction_size(self): 3543 return 1 3544 3545 def get_reduction_type(self): 3546 return None 3547 3548 def is_no_op(self): 3549 return False 3550 3551 def should_allocate(self): 3552 return True 3553 3554 def simplify_and_reorder( 3555 self, 3556 extra_indexing_constraints: Optional[Tuple[Dict[Any, Any], List[Any]]] = None, 3557 ): 3558 return ( 3559 ( 3560 self.get_size(), 3561 (), 3562 ), 3563 None, 3564 ) 3565 3566 3567class TritonTemplateBuffer(TemplateBuffer): 3568 def __init__( 3569 self, 3570 layout, 3571 inputs, 3572 make_kernel_render, 3573 debug_extra=None, 3574 mutated_inputs: Optional[Iterable[IRNode]] = None, 3575 ): 3576 """ 3577 NOTE:[TritonTemplates with multiple outputs] 3578 We want the ability for TritonTemplates to output multiple tensors. Triton 3579 kernels have no notion of outputs and this is done by creating tensors that 3580 are then mutated by the kernel. Currenlty our STORE_OUTPUT codegen doesn't 3581 support creating multinode outputs for triton templates. 
        We work around this by creating extra input buffers during the lowering
        and marking them as mutated inputs.
        """
        super().__init__(layout, inputs, make_kernel_render)
        self.debug_extra = debug_extra
        self.mutated_inputs = mutated_inputs
        if mutated_inputs is not None:
            # Ensure that the mutated inputs are only allowed for certain nodes
            allowed_set = {
                torch.ops.higher_order.flex_attention,
                torch.ops.higher_order.flex_attention_backward,
            }
            current_node = V.graph.current_node.target
            assert (
                current_node in allowed_set
            ), f"Mutated inputs are only allowed for {allowed_set} but got {current_node}"
            mark_node_as_mutating(self, *mutated_inputs)

    def __str__(self):
        out = f"TritonTemplateBuffer(layout={self.layout}, {self.debug_extra})"
        return out


PrimitiveInfoType = Union[int, float, bool, str, List[Union[int, str, float, bool]]]


class ChoiceCaller:
    """
    Represents a possible choice used in autotune_process.py.
    During autotuning, self.benchmark() is first called to get a benchmark result,
    and if this choice is selected, self.output_node() is called to get the output_node.

    Subclasses: TritonTemplateCaller, CUDATemplateCaller.
    """

    def __init__(self, name, input_nodes, layout):
        super().__init__()
        self.name = name
        self.layout = layout
        self.input_nodes = input_nodes

    def benchmark(self, *args, out) -> float:
        algo = self.to_callable()
        return do_bench(algo, args, {"out": out})

    def call_name(self) -> str:
        raise NotImplementedError

    def to_callable(self):
        raise NotImplementedError

    def hash_key(self) -> str:
        raise NotImplementedError

    def output_node(self) -> "TensorBox":
        raise NotImplementedError

    def info_dict(self) -> Dict[str, Union[PrimitiveInfoType, List[PrimitiveInfoType]]]:
        """Information returned here is logged to the autotune log file when that is enabled."""
        return {}


class TritonTemplateCallerBase(ChoiceCaller):
    def get_make_kernel_render(self) -> Any:
        raise NotImplementedError


class MultiTemplateBuffer(TritonTemplateBuffer):
    """
    Represents a Buffer with multiple backing implementation choices.

    Choices can be TritonTemplates or ExternKernels. During scheduling, if there is a potential
    epilogue, we will benchmark each of the choices with the epilogue to determine an implementation.
    Otherwise, the fastest base choice will be chosen.
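
    Rough usage sketch (added for illustration; assumes the fastest choice is a
    Triton template caller):

        multi = ...  # a MultiTemplateBuffer produced during lowering
        caller, timing = multi.get_min_choice()
        with multi.swap_as_triton_caller(caller):
            ...  # benchmark the kernel together with a candidate epilogue
        multi.finalize_as_triton_caller(caller)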
3656 """ 3657 3658 def __init__( 3659 self, 3660 layout: Layout, 3661 inputs: List[IRNode], 3662 choice_timings: Callable[[], Dict[ChoiceCaller, float]], 3663 ): 3664 super().__init__(layout=layout, inputs=inputs, make_kernel_render=None) 3665 self._choice_timings_fn = choice_timings 3666 self._choice_timings: Optional[Dict[ChoiceCaller, float]] = None 3667 self.original_inputs = inputs 3668 3669 @property 3670 def choice_timings(self) -> Dict[ChoiceCaller, float]: 3671 if self._choice_timings is None: 3672 self._choice_timings = self._choice_timings_fn() 3673 return self._choice_timings 3674 3675 @contextlib.contextmanager 3676 def swap_as_triton_caller(self, caller: TritonTemplateCallerBase): 3677 assert isinstance(caller, torch._inductor.select_algorithm.TritonTemplateCaller) 3678 assert self.layout == caller.layout 3679 3680 render = self.make_kernel_render 3681 self.make_kernel_render = caller.get_make_kernel_render() 3682 try: 3683 yield 3684 finally: 3685 self.make_kernel_render = render 3686 3687 def finalize_as_triton_caller(self, caller: TritonTemplateCallerBase): 3688 assert isinstance(caller, torch._inductor.select_algorithm.TritonTemplateCaller) 3689 assert self.layout.size == caller.layout.size 3690 assert self.layout.stride == caller.layout.stride 3691 self.make_kernel_render = caller.get_make_kernel_render() 3692 3693 def get_min_choice(self) -> Tuple[ChoiceCaller, float]: 3694 min_choice = min(self.choice_timings, key=self.choice_timings.get) # type: ignore[arg-type] 3695 return (min_choice, self.choice_timings[min_choice]) 3696 3697 3698class CUDATemplateBuffer(TemplateBuffer): 3699 def __init__( 3700 self, 3701 layout, 3702 inputs, 3703 make_kernel_render, 3704 workspace_size: int, 3705 template: "CUDATemplate", # type: ignore[name-defined] # noqa: F821 3706 ): 3707 super().__init__(layout, inputs, make_kernel_render) 3708 # Global memory (in bytes) needed for this template. 
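        # (Added note: exposed via get_workspace_size() below; a value of 0
        # means the template needs no extra scratch allocation.)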
3709 self.workspace_size = workspace_size 3710 self.template = template 3711 3712 def get_workspace_size(self): 3713 return self.workspace_size if self.workspace_size is not None else 0 3714 3715 3716class CppTemplateBuffer(TemplateBuffer): 3717 def __init__(self, layout, inputs, make_kernel_render, template, choice): 3718 super().__init__(layout, inputs, make_kernel_render) 3719 self.template = template 3720 self.choice = choice 3721 3722 3723@dataclasses.dataclass 3724class InputsKernel(Buffer): 3725 inputs: List[Buffer] 3726 3727 def get_read_writes_input(self, x): 3728 return dependencies.StarDep(x.get_name()) 3729 3730 def get_read_writes(self): 3731 star_dep = [] 3732 for input in self.inputs: 3733 if isinstance(input, list): 3734 star_dep.extend([self.get_read_writes_input(x) for x in input]) 3735 else: 3736 star_dep.append(self.get_read_writes_input(input)) 3737 3738 return dependencies.ReadWrites( 3739 set(star_dep), 3740 {dependencies.StarDep(self.get_name())}, 3741 set(), 3742 [], 3743 None, 3744 op_counts=collections.Counter(), 3745 ) 3746 3747 @classmethod 3748 def unwrap_storage_for_input(cls, x): 3749 if isinstance(x, TensorBox): 3750 x = x.data 3751 if isinstance(x, StorageBox): 3752 x = x.data 3753 if isinstance(x, BaseView) and not isinstance(x, ReinterpretView): 3754 x = ExternKernel.realize_input(x) 3755 if isinstance(x, TensorBox): 3756 # when converting to ReinterpretView fails in the 3757 # realize_input call above, the result will be wrapped 3758 # into TensorBox / StorageBox pair as a result of the 3759 # cls.copy_input call; so we should unwrap recursively 3760 return cls.unwrap_storage_for_input(x) 3761 if isinstance(x, TorchBindObject): 3762 return x 3763 assert isinstance(x, (Buffer, ReinterpretView)), x 3764 return x 3765 3766 @staticmethod 3767 def unwrap_storage(inputs): 3768 inputs_new = [] 3769 for x in inputs: 3770 if isinstance(x, list): 3771 x = [InputsKernel.unwrap_storage_for_input(i) for i in x] 3772 else: 3773 x = InputsKernel.unwrap_storage_for_input(x) 3774 inputs_new.append(x) 3775 return inputs_new 3776 3777 def is_extern(self): 3778 return True 3779 3780 3781class NopKernel(InputsKernel): 3782 def is_no_op(self): 3783 return True 3784 3785 3786class ConcatKernel(NopKernel): 3787 """ 3788 There isn't actually a real kernel for concat, we just change the 3789 storage for the upstream data. 
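
    For example (added for illustration): concatenating a (2, 3) and a (4, 3)
    input along dim 0 allocates a single (6, 3) output buffer and realizes each
    input into a SliceView of that buffer (rows 0:2 and 2:6 respectively), so
    no separate concat kernel is emitted.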
3790 """ 3791 3792 @classmethod 3793 def create(cls, inputs, dim): 3794 device = inputs[0].get_device() 3795 dtype = inputs[0].get_dtype() 3796 new_size = list(inputs[0].get_size()) 3797 offsets_start = [0] 3798 offsets_end = [new_size[dim]] 3799 assert 0 <= dim < len(new_size) 3800 for i in range(1, len(inputs)): 3801 input_size = inputs[i].get_size() 3802 offsets_start.append(new_size[dim]) 3803 assert len(input_size) == len(new_size) 3804 assert inputs[i].get_dtype() == dtype 3805 assert inputs[i].get_device() == device 3806 for j in range(len(new_size)): 3807 if j == dim: 3808 new_size[j] = new_size[j] + input_size[j] 3809 else: 3810 new_size[j] = V.graph.sizevars.guard_equals( 3811 new_size[j], input_size[j] 3812 ) 3813 offsets_end.append(new_size[dim]) 3814 3815 output_stride = FlexibleLayout.contiguous_strides(new_size) 3816 # If any of the inputs is in CL format, use CL format for the output 3817 for i in range(len(inputs)): 3818 x = inputs[i] 3819 if is_storage_and_layout(x): 3820 layout = x.get_layout() 3821 if isinstance( 3822 layout, FixedLayout 3823 ) and Layout.is_channels_last_contiguous(layout.size, layout.stride): 3824 # use CL stride for the output 3825 output_stride = make_channels_last_strides_for(new_size) 3826 break 3827 any_input_is_storage_and_layout = any(is_storage_and_layout(x) for x in inputs) 3828 fx_node_args = V.graph.current_node.args[0] 3829 assert isinstance(fx_node_args, list) 3830 # If any of the inputs has meta tensor and the meta tensor is in CL format, use CL format for the output 3831 if any_input_is_storage_and_layout is False and any( 3832 "val" in arg.meta 3833 and ( 3834 arg.meta["val"].is_contiguous(memory_format=torch.channels_last) 3835 or arg.meta["val"].is_contiguous(memory_format=torch.channels_last_3d) 3836 ) 3837 for arg in fx_node_args 3838 ): 3839 output_stride = make_channels_last_strides_for(new_size) 3840 3841 concat_kernel = ConcatKernel( 3842 name=None, 3843 layout=FixedLayout( 3844 device=device, 3845 dtype=dtype, 3846 size=new_size, 3847 stride=output_stride, 3848 ), 3849 inputs=[], 3850 ) 3851 kernel = StorageBox(concat_kernel) 3852 buffer_names = [] 3853 for i in range(len(inputs)): 3854 input_buffer = cls.realize_into( 3855 inputs[i], 3856 SliceView.create( 3857 kernel, dim, offsets_start[i], offsets_end[i], clamp=False 3858 ), 3859 ) 3860 concat_kernel.inputs.append(input_buffer) 3861 3862 if isinstance(inputs[i].data, BaseView): 3863 input_unwrapped = inputs[i].data.unwrap_view() 3864 else: 3865 input_unwrapped = inputs[i].data 3866 3867 if ( 3868 input_unwrapped.is_input_buffer() 3869 and is_gpu(inputs[i].get_device().type) 3870 and not is_dynamic(input_buffer) 3871 ): 3872 buffer_names.append(input_buffer.get_name()) 3873 3874 if len(buffer_names) > 1: 3875 V.graph.register_list(buffer_names) 3876 3877 concat_kernel.name = V.graph.register_buffer(concat_kernel) 3878 concat_kernel.inputs = cls.unwrap_storage(concat_kernel.inputs) 3879 3880 return kernel 3881 3882 @classmethod 3883 def can_realize_into_without_copy(cls, src): 3884 if isinstance(src, TensorBox): 3885 # unwrap a TensorBox 3886 return cls.can_realize_into_without_copy(src.data) 3887 3888 return isinstance(src.data.layout, FlexibleLayout) and not isinstance( 3889 src.data, ExternKernelAlloc 3890 ) 3891 3892 @classmethod 3893 def realize_into(cls, src, dst): 3894 # Attempt to turn this into a ReinterpretView rather than assert. 3895 # This has concessions around layout, as as_storage_and_layout 3896 # can cause us to go from flexible to fixed layout. 
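        # Added commentary on the cases handled below:
        #   * a storage-backed dst that is not yet a ReinterpretView is wrapped
        #     into one (a view of the concat output at this input's offset).
        #   * a TensorBox src is unwrapped; a StorageBox src is realized and, if
        #     its layout is still flexible (and it is not an ExternKernelAlloc),
        #     aliased into dst via NonOwningLayout.
        #   * otherwise a Pointwise copy of src is created and realized into dst.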
3897        if not isinstance(dst, ReinterpretView):
3898            if is_storage_and_layout(dst):
3899                storage, layout = as_storage_and_layout(dst)
3900                dst = ReinterpretView(storage, layout)
3901        assert isinstance(dst, ReinterpretView), dst
3902        if isinstance(src, TensorBox):
3903            # unwrap a TensorBox
3904            return cls.realize_into(src.data, dst)
3905        if isinstance(src, StorageBox):
3906            src.realize()
3907            # ExternKernelAlloc has specific requirements for output layout, so we should create a copy
3908            assert hasattr(src.data, "layout")
3909            if cls.can_realize_into_without_copy(src):
3910                src.data.layout = NonOwningLayout(dst)
3911                return src.data
3912        # introduce a copy
3913        pw = Pointwise.create(
3914            device=src.get_device(),
3915            dtype=src.get_dtype(),
3916            inner_fn=src.make_loader(),
3917            ranges=[
3918                V.graph.sizevars.guard_equals(a, b)
3919                for a, b in zip(src.get_size(), dst.get_size())
3920            ],
3921        )
3922        return cls.realize_into(pw, dst)
3923
3924    def should_allocate(self):
3925        return True
3926
3927
3928def get_aten_cpp_kernel_name(kernel):
3929    # Calling with the default kernel name can lead to ambiguous behavior like the following example.
3930    # repeat_interleave(const at::Tensor & repeats, c10::optional<int64_t> output_size=c10::nullopt)
3931    # repeat_interleave(const at::Tensor & self, int64_t repeats,
3932    #     c10::optional<int64_t> dim=c10::nullopt, c10::optional<int64_t> output_size=c10::nullopt)
3933    if not isinstance(kernel, torch._ops.OpOverload) or kernel.namespace != "aten":
3934        return None
3935    opname = (
3936        kernel.__name__.split(".")[0]
3937        if kernel._overloadname == "default"
3938        else kernel.__name__.replace(".", "_")
3939    )
3940    return f"at::_ops::{opname}::call"
3941
3942
3943@dataclasses.dataclass
3944class ExternKernel(InputsKernel):
3945    constant_args: Tuple[Any, ...] = ()
3946    kwargs: Dict[str, Any] = dataclasses.field(default_factory=dict)
3947    output_view: Optional[ReinterpretView] = None
3948    python_kernel_name: Optional[str] = None
3949    cpp_kernel_name: Optional[str] = None
3950    # FIXME: in some cases we still need to explicitly pass in ordered_kwargs_for_cpp_kernel
3951    # We shouldn't need to do this since the information can be retrieved from op_overload._schema.
3952    ordered_kwargs_for_cpp_kernel: Iterable[str] = dataclasses.field(
3953        default_factory=list
3954    )
3955    op_overload: Optional[
3956        Union[torch._ops.OpOverload, torch._ops.HigherOrderOperator]
3957    ] = None
3958    arg_properties: Optional[List[Dict[str, Any]]] = None
3959    kwarg_properties: Optional[Dict[str, Dict[str, Any]]] = None
3960    unbacked_bindings: Dict[sympy.Symbol, pytree.KeyPath] = dataclasses.field(
3961        default_factory=dict
3962    )
3963
3964    def __init__(
3965        self,
3966        name,
3967        layout,
3968        inputs,
3969        constant_args=(),
3970        kwargs=None,
3971        output_view=None,
3972        python_kernel_name=None,
3973        cpp_kernel_name=None,
3974        ordered_kwargs_for_cpp_kernel=(),
3975        op_overload=None,
3976    ):
3977        super().__init__(
3978            name,
3979            layout,
3980            inputs,
3981        )
3982        self.constant_args = constant_args
3983        self.kwargs = kwargs if kwargs else {}
3984        self.output_view = output_view
3985        self.python_kernel_name = python_kernel_name
3986        # If cpp_kernel_name is None, we will try to construct it from op_overload
3987        self.cpp_kernel_name = cpp_kernel_name or get_aten_cpp_kernel_name(op_overload)
3988        self.ordered_kwargs_for_cpp_kernel = ordered_kwargs_for_cpp_kernel
3989        self.op_overload = op_overload
3990        self.collect_arg_kwarg_properties()
3991        self.unbacked_bindings = {}
3992        self.fx_node = V.graph.current_node
3993
3994    def get_unbacked_symbol_defs(self) -> Set[sympy.Symbol]:
3995        return set()
3996
3997    def collect_arg_kwarg_properties(self):
3998        # If self.op_overload is a torch._ops.OpOverload, we can use its schema to collect additional
3999        # information for args and kwargs, e.g. type and default value, to help with the cpp wrapper codegen
4000        self.arg_properties = (
4001            [
4002                {
4003                    "name": x.name,
4004                    "type": x.real_type,
4005                    "default_value": x.default_value,
4006                }
4007                for x in self.op_overload._schema.arguments
4008                if not x.kwarg_only
4009            ]
4010            if isinstance(self.op_overload, torch._ops.OpOverload)
4011            else [{} for i in range(len(self.inputs))]
4012        )
4013        self.allarg_properties = (
4014            {
4015                x.name: {"type": x.real_type, "default_value": x.default_value}
4016                for x in self.op_overload._schema.arguments
4017            }
4018            if isinstance(self.op_overload, torch._ops.OpOverload)
4019            else {}
4020        )
4021        # FIXME: self.kwargs does not always match kwargs defined in schema, so sometimes
4022        # ordered_kwargs_for_cpp_kernel is explicitly passed in.
4023        if (
4024            isinstance(self.op_overload, torch._ops.OpOverload)
4025            and not self.ordered_kwargs_for_cpp_kernel
4026        ):
4027            self.ordered_kwargs_for_cpp_kernel = [
4028                x.name for x in self.op_overload._schema.arguments if x.kwarg_only
4029            ]
4030
4031    def fill_non_provided_args(self, args, kwargs, convert_val_to_str=False):
4032        # Previously, we wanted to maintain forward-compatibility by skipping
4033        # default args in the serialized artifacts in fbcode. However,
4034        # some of our shim interfaces require default values to be set.
4035        # We discussed with Sherlock offline and decided to allow serializing
4036        # default args into the C++ wrapper code for now. We will refine this
4037        # part if we see a real FC requirement.
More details related to FC 4038 # can be found at: 4039 # https://docs.google.com/document/d/1FzWm-sHYwmRi3x_g036kOxd99KaYquUsA-L5JwOn8ys/edit?usp=sharing 4040 assert isinstance(args, (list, tuple)) 4041 if isinstance(args, tuple): 4042 args = list(args) 4043 assert self.arg_properties, "ExternKernel.arg_properties should not be empty" 4044 4045 n_args = len(args) 4046 n_pos_args = len(self.arg_properties) 4047 # For cpp wrapper, if some positional args are not provided, we need to check 4048 # if they're in the kwargs or use their default value 4049 if n_args < n_pos_args: 4050 log.debug( 4051 "%s has %d unprovided positional arguments. " 4052 "Will check if they are in the keyword arguments or will use default values.", 4053 self.op_overload, 4054 n_pos_args - n_args, 4055 ) 4056 for i in range(n_args, n_pos_args): 4057 arg_name = self.arg_properties[i]["name"] 4058 args.append( 4059 kwargs[arg_name] 4060 if arg_name in kwargs 4061 else self.arg_properties[i]["default_value"] 4062 ) 4063 return args 4064 4065 def decide_layout(self): 4066 if isinstance(self.layout, FlexibleLayout): 4067 self.apply_constraint() 4068 self.freeze_layout() 4069 4070 def codegen_comment(self, wrapper): 4071 origin_str, detailed_origin_str = get_kernel_metadata(self, wrapper) 4072 if origin_str: 4073 wrapper.writeline(origin_str) 4074 4075 def codegen(self, wrapper): 4076 raise NotImplementedError 4077 4078 def get_kernel_name(self): 4079 return ( 4080 ( 4081 V.graph.wrapper_code.get_c_shim_func_name(self.cpp_kernel_name) # type: ignore[attr-defined] 4082 if config.abi_compatible 4083 else self.cpp_kernel_name 4084 ) 4085 if V.graph.cpp_wrapper 4086 else self.python_kernel_name 4087 ) 4088 4089 @staticmethod 4090 def copy_input(x): 4091 pw = Pointwise.create( 4092 device=x.get_device(), 4093 dtype=x.get_dtype(), 4094 inner_fn=x.make_loader(), 4095 ranges=x.get_size(), 4096 origin_node=x.get_origin_node(), 4097 traceback=x.get_traceback(), 4098 ) 4099 pw.realize() 4100 return pw 4101 4102 @classmethod 4103 def process_kernel( 4104 cls, kernel, *args, **kwargs 4105 ) -> Tuple[ 4106 Any, 4107 List[Any], 4108 List[Any], 4109 Callable[[Any, Any], Any], 4110 Optional[Dict[sympy.Symbol, pytree.KeyPath]], 4111 ]: 4112 binded_args = {"args": args, "kwargs": kwargs} 4113 4114 args_flat, args_spec = pytree.tree_flatten(binded_args) 4115 4116 is_arg_tensor = [] 4117 tensor_args = [] 4118 non_tensor_args: List[Any] = [] 4119 for arg in args_flat: 4120 is_arg_tensor.append(isinstance(arg, IRNode)) 4121 if is_arg_tensor[-1]: 4122 tensor_args.append(arg) 4123 else: 4124 if isinstance(arg, sympy.Expr): 4125 arg = V.graph.sizevars.shape_env.create_symintnode(arg, hint=None) 4126 non_tensor_args.append(arg) 4127 4128 def unflatten_args(new_tensor_args, new_non_tensor_args): 4129 result = [] 4130 it_tensors = iter(new_tensor_args) 4131 it_non_tensors = iter(new_non_tensor_args) 4132 for is_tensor in is_arg_tensor: 4133 if is_tensor: 4134 result.append(next(it_tensors)) 4135 else: 4136 result.append(next(it_non_tensors)) 4137 r = pytree.tree_unflatten(result, args_spec) 4138 return r.get("args", []), r.get("kwargs", {}) 4139 4140 tensor_args = [cls.realize_input(x) for x in tensor_args] 4141 4142 # freeze layout otherwise our output stride calculation might 4143 # become incorrect 4144 for x in tensor_args: 4145 if is_storage_and_layout(x): 4146 as_storage_and_layout(x, freeze=True) 4147 4148 # Rerun fake tensor propagation, because Inductor may have changed the 4149 # strides of inputs and we need to determine accurately what the 
4150 # output stride will be. 4151 example_args: List[Union[torch.Tensor, torch._C.ScriptObject]] = [] 4152 4153 # We need to retain the constant values of fake tensors that we originally 4154 # propagated the graph with, because for some operators running without a 4155 # constant would trigger an error / DataDependentException 4156 for x in tensor_args: 4157 if x.get_name() in V.graph.constants: 4158 example_args.append(V.graph.constants[x.get_name()]) 4159 elif x.get_name() in V.graph.torchbind_constants: 4160 example_args.append(V.graph.torchbind_constants[x.get_name()]) 4161 else: 4162 example_args.append(ir_node_to_tensor(x, guard_shape=True)) 4163 4164 new_args, new_kwargs = unflatten_args(example_args, non_tensor_args) 4165 example_output = kernel(*new_args, **new_kwargs) 4166 4167 unbacked_bindings: Optional[Dict[sympy.Symbol, pytree.KeyPath]] = None 4168 if shape_env := V.fake_mode.shape_env: 4169 rebind_unbacked(shape_env, V.current_node, example_output) 4170 unbacked_bindings = compute_unbacked_bindings( 4171 shape_env, example_output, V.current_node.meta.get("val") 4172 ) 4173 4174 example_out_li = ( 4175 [example_output] 4176 if not isinstance(example_output, (list, tuple)) 4177 else example_output 4178 ) 4179 for t in example_out_li: 4180 if isinstance(t, torch.Tensor) and t.is_sparse: 4181 msg = "sparsity not handled. Please file issue for sparse inference weights." 4182 if stack_trace := V.graph.current_node.meta.get("stack_trace", None): 4183 msg = f"{msg} Found from : \n {stack_trace}" 4184 V.graph.disable_cudagraphs_reason = msg 4185 4186 return ( 4187 example_output, 4188 tensor_args, 4189 non_tensor_args, 4190 unflatten_args, 4191 unbacked_bindings, 4192 ) 4193 4194 @classmethod 4195 def convert_to_reinterpret_view(cls, x): 4196 """ 4197 In order to pass this to an extern kernel we need a 4198 ReinterpretView not a View. This allows us to avoid some 4199 unneeded copies. 4200 """ 4201 assert isinstance(x, BaseView) 4202 if isinstance(x, ReinterpretView): 4203 return x 4204 4205 # NOTE: Don't use extract_read_writes here as it fails when 4206 # make_loader() inlines the computation 4207 x_unwrap_view = x.unwrap_view() 4208 x_unwrap_view_fx_node = V.graph.get_buffer( 4209 x_unwrap_view.get_name() 4210 ).get_origin_node() 4211 # Prefer channels last format according to how the format is set from eager. 
4212 if ( 4213 x_unwrap_view_fx_node is not None 4214 and "val" in x_unwrap_view_fx_node.meta 4215 and isinstance(x_unwrap_view.layout, FlexibleLayout) 4216 and ( 4217 x_unwrap_view_fx_node.meta["val"].is_contiguous( 4218 memory_format=torch.channels_last 4219 ) 4220 or x_unwrap_view_fx_node.meta["val"].is_contiguous( 4221 memory_format=torch.channels_last_3d 4222 ) 4223 ) 4224 ): 4225 x_unwrap_view.freeze_layout_with_same_order( 4226 make_channels_last_strides_for(x_unwrap_view.get_size()) 4227 ) 4228 else: 4229 x_unwrap_view.freeze_layout() 4230 4231 index_args, var_ranges = dependencies.index_vars_squeeze( 4232 x.get_size(), prefix="r" 4233 ) 4234 range_vars = index_args[0] 4235 index = x.make_indexer()(range_vars) 4236 4237 index = V.graph.sizevars.simplify_with_ranges(index, var_ranges) 4238 strides = V.graph.sizevars.stride_vars(index, range_vars) 4239 offset = V.graph.sizevars.offset_var(index, range_vars) 4240 expected = sympy_dot(range_vars, strides) + offset 4241 4242 if index != expected: 4243 log.debug( 4244 "convert_to_reinterpret_view failed: stride=%s offset=%s index=%s", 4245 strides, 4246 offset, 4247 index, 4248 ) 4249 raise NotImplementedError 4250 4251 return ReinterpretView( 4252 data=x.data, 4253 layout=FixedLayout( 4254 device=x.get_device(), 4255 dtype=x.get_dtype(), 4256 size=x.get_size(), 4257 stride=strides, 4258 offset=offset, 4259 ), 4260 ) 4261 4262 @classmethod 4263 def realize_input(cls, x): 4264 if x is None: 4265 return NoneAsConstantBuffer() 4266 if isinstance(x, (sympy.Expr, sympy.logic.boolalg.Boolean, int)): 4267 return ShapeAsConstantBuffer(x) 4268 if isinstance(x, Constant): 4269 return V.graph.add_tensor_constant( 4270 torch.tensor(x.value, dtype=x.get_dtype(), device=x.get_device()) 4271 ) 4272 if isinstance(x, ConstantBuffer): 4273 return x 4274 if isinstance(x, TensorBox): 4275 return cls.realize_input(x.data) 4276 if isinstance(x, ReinterpretView): 4277 return ReinterpretView(cls.realize_input(x.data), x.get_layout()) 4278 if isinstance(x, BaseView): 4279 x.realize() 4280 if is_storage_and_layout(x.unwrap_view()): 4281 try: 4282 return cls.convert_to_reinterpret_view(x) 4283 except NotImplementedError: 4284 pass 4285 if isinstance(x, StorageBox): 4286 # TODO(jansel): impose layout preference on realized buffer 4287 x.realize() 4288 return x 4289 if isinstance(x, TorchBindObject): 4290 return x 4291 return cls.copy_input(x) 4292 4293 @classmethod 4294 def require_stride1(cls, x): 4295 if is_storage_and_layout(x): 4296 if len(x.get_stride()) == 0: 4297 return x 4298 for stride in x.get_stride(): 4299 if stride == 1: 4300 return x 4301 return cls.copy_input(x) 4302 4303 @classmethod 4304 def require_stride_order(cls, x, order, allow_padding=False): 4305 if x.get_numel() == 0: # Layout doesn't matter 4306 return x 4307 4308 # require x to have the layout as strided_ordered as order 4309 if is_storage_and_layout(x): 4310 while isinstance(x.get_layout(), NonOwningLayout): 4311 x = x.get_layout().view 4312 if isinstance(x.get_layout(), FlexibleLayout): 4313 # If the the FlexibleLayout already has the size and stride in the required order, 4314 # freeze it to a FixedLayout by using its current size and stride. 4315 # The behavior of using its current size and stride or the given order can be different 4316 # if the size and stride has ambiguilty, for example for a 4D input where the iC = 1: 4317 # size=[s0, 1, 28, 28], stride=[784, 784, 28, 1]. 
If the required order is [3, 0, 2, 1] (channels last), 4318 # the current size and stride already satisfies this order. 4319 # However by freezing it to the required order, the layout will be changed to: 4320 # size=[s0, 1, 28, 28], stride=[784, 1, 28, 1]), which is not actually necessary. 4321 4322 # fix flexiblelayout to be FixedLayout with stride_order 4323 as_storage_and_layout( 4324 x, 4325 freeze=True, 4326 want_contiguous=False, 4327 stride_order=get_stride_order( 4328 V.graph.sizevars.size_hints(x.get_layout().stride) 4329 ) 4330 if is_stride_order_storage_and_layout(x, order) 4331 else order, 4332 allow_padding=allow_padding, 4333 ) 4334 return x 4335 elif isinstance( 4336 x.get_layout(), FixedLayout 4337 ) and x.get_layout().is_stride_ordered(order): 4338 return x 4339 elif isinstance(x.get_layout(), MutationLayoutSHOULDREMOVE): 4340 if isinstance(x.get_layout().real_layout(), FlexibleLayout): 4341 raise AssertionError( 4342 "the MutationLayoutSHOULDREMOVE's real layout shouldn't be FlexibleLayout" 4343 ) 4344 elif isinstance( 4345 x.get_layout().real_layout(), FixedLayout 4346 ) and x.get_layout().real_layout().is_stride_ordered(order): 4347 return x 4348 4349 # TODO - Storage to InputBuffer 4350 if isinstance(x, InputBuffer) and x.get_layout().is_stride_ordered(order): 4351 return x 4352 if ( 4353 isinstance(x, TensorBox) 4354 and isinstance(x.data, BaseView) 4355 and not isinstance(x.data, ReinterpretView) 4356 and is_storage_and_layout(x.unwrap_view()) 4357 and not isinstance(x.unwrap_view().data, ExternKernelAlloc) 4358 ): 4359 try: 4360 x.data = cls.convert_to_reinterpret_view(x.data) 4361 return cls.require_stride_order(x, order, allow_padding=allow_padding) 4362 except NotImplementedError: 4363 pass 4364 x = cls.copy_input(x) 4365 as_storage_and_layout( 4366 x, 4367 freeze=True, 4368 want_contiguous=False, 4369 stride_order=order, 4370 allow_padding=allow_padding, 4371 ) 4372 assert is_stride_order_storage_and_layout(x, order) 4373 return x 4374 4375 @classmethod 4376 def require_channels_last(cls, x): 4377 return cls.require_stride_order(x, NHWC_STRIDE_ORDER) 4378 4379 @classmethod 4380 def require_channels_last_3d(cls, x): 4381 return cls.require_stride_order(x, NHWDC_STRIDE_ORDER) 4382 4383 @classmethod 4384 def require_contiguous(cls, x): 4385 return cls.require_stride_order(x, list(reversed(range(len(x.get_size()))))) 4386 4387 def apply_constraint(self): 4388 pass 4389 4390 def codegen_const_args(self): 4391 if V.graph.cpp_wrapper: 4392 result = [] 4393 for i, x in enumerate(self.constant_args): 4394 idx = len(self.inputs) + i 4395 type_ = ( 4396 self.arg_properties[i].get("type") 4397 if self.arg_properties and idx < len(self.arg_properties) 4398 else None 4399 ) 4400 result.append( 4401 V.graph.wrapper_code.val_to_arg_str(x, type_) # type: ignore[arg-type] 4402 ) 4403 return result 4404 else: 4405 return map(V.graph.wrapper_code.val_to_arg_str, self.constant_args) 4406 4407 def codegen_args(self): 4408 args = [] 4409 for i, x in enumerate(self.inputs): 4410 if isinstance(x, list): 4411 names = [i.codegen_reference() for i in x] 4412 codegen_reference = f'[{", ".join(names)}]' 4413 args.append(codegen_reference) 4414 else: 4415 if V.graph.cpp_wrapper: 4416 assert self.arg_properties and i < len( 4417 self.arg_properties 4418 ), "Invalid access to ExternKernel.arg_properties" 4419 type_ = self.arg_properties[i].get("type") 4420 args.append( 4421 V.graph.wrapper_code.val_to_arg_str( # type: ignore[arg-type] 4422 x, type_ 4423 ) 4424 ) 4425 else: 4426 
args.append(x.codegen_reference()) 4427 args.extend(self.codegen_const_args()) 4428 return args 4429 4430 def get_kwargs_value(self, arg_name): 4431 if arg_name in self.kwargs: 4432 return self.kwargs.get(arg_name) 4433 if self.allarg_properties and self.allarg_properties.get(arg_name): 4434 return self.allarg_properties.get(arg_name).get("default_value") # type: ignore[union-attr] 4435 else: 4436 raise AssertionError(f"{arg_name} not in self.allarg_properties") 4437 4438 def codegen_kwargs(self, skip_out=False): 4439 if V.graph.cpp_wrapper: 4440 kwargs = [] 4441 for arg_name in self.ordered_kwargs_for_cpp_kernel: 4442 if skip_out and arg_name == "out": 4443 # ExternKernelOut has its own logic for inserting the out parameter 4444 continue 4445 4446 v = self.get_kwargs_value(arg_name) 4447 if isinstance(v, sympy.Expr): 4448 kwargs.append(v) 4449 else: 4450 type_ = ( 4451 self.allarg_properties.get(arg_name).get("type") # type: ignore[union-attr] 4452 if self.allarg_properties and arg_name in self.allarg_properties 4453 else None 4454 ) 4455 kwargs.append( 4456 V.graph.wrapper_code.val_to_arg_str( # type: ignore[arg-type] 4457 v, type_ 4458 ) 4459 ) 4460 else: 4461 kwargs = [ 4462 f"{k}={V.graph.wrapper_code.val_to_arg_str(v)}" # type: ignore[misc] 4463 for k, v in self.kwargs.items() 4464 ] 4465 return kwargs 4466 4467 def codegen_size_asserts(self, wrapper): 4468 if config.size_asserts and not V.graph.cpp_wrapper: 4469 # comparing strides for 0 size tensor is tricky. Ignore them for now. 4470 if sympy_product(self.get_size()) == 0: 4471 return 4472 size = V.graph.wrapper_code.codegen_shape_tuple(self.get_size()) 4473 stride = V.graph.wrapper_code.codegen_shape_tuple(self.get_stride()) 4474 wrapper.writeline( 4475 f"assert_size_stride({self.get_name()}, {size}, {stride})" 4476 ) 4477 4478 def get_group_stride(self): 4479 """ 4480 get output sizes and strides, for template_codegen 4481 """ 4482 _size = self.get_size() 4483 _stride = self.get_stride() 4484 # iter_ranges = _size of output tensor, reduce_range = [] because no reduction 4485 return [_size, []], _stride 4486 4487 def canonicalize(self): 4488 """ 4489 Manually get canonicalization of the output index 4490 """ 4491 # manually generate index formula for conv 4492 sizevars = V.graph.sizevars 4493 sizes = self.get_size() 4494 strides = self.get_stride() 4495 strides = [sizevars.size_hint(x) for x in strides] 4496 # TODO: I can't tell if the symbols here are temporary 4497 index_vars = [sympy_index_symbol(f"d{i}") for i in range(len(sizes))] 4498 # reorder index vars according to stride 4499 index_order = sorted(range(len(strides)), key=strides.__getitem__, reverse=True) 4500 lookup = {pos: idx for idx, pos in enumerate(index_order)} 4501 order = [lookup[i] for i in range(len(lookup))] 4502 index_vars = [index_vars[i] for i in order] 4503 indexer = self.make_indexer() 4504 index = indexer(index_vars) 4505 4506 new_sizes, reindex, prune = V.graph.sizevars._simplify_loops( 4507 index_vars, sizes, [index] 4508 ) 4509 4510 # assign new variables each dimension to deal with numbering mismatches 4511 # d0, d1, d2 could become d0, d2 -- which won't match d0, d1 4512 _, add_var = var_builder("c") 4513 replacement = dict(zip(index_vars, reindex([add_var(x) for x in new_sizes]))) 4514 4515 index = sympy_subs(sympy.expand(index), replacement) # type: ignore[arg-type] 4516 return index, tuple(new_sizes) 4517 4518 def get_unbacked_symbol_uses(self) -> Set[sympy.Symbol]: 4519 # NB: It's not necessary to check regular inputs as we automatically 4520 
# have dependencies on them 4521 r = set() 4522 for arg in self.constant_args: 4523 r |= maybe_free_unbacked_symbols(arg) 4524 for arg in self.kwargs.values(): 4525 r |= maybe_free_unbacked_symbols(arg) 4526 return r 4527 4528 def __str__(self): 4529 kernel_name = getattr(self, "python_kernel_name", None) 4530 lines = [ 4531 f"python_kernel_name={kernel_name!r}", 4532 ] 4533 lines += [ 4534 f"{field.name}={getattr(self, field.name)}" 4535 for field in dataclasses.fields(self) 4536 ] 4537 lines.append(f"origin_node={self.origin_node!r}") 4538 return self.str_helper(lines) 4539 4540 __repr__ = __str__ 4541 4542 4543@dataclasses.dataclass 4544class ExternKernelOut(ExternKernel): 4545 def codegen(self, wrapper): 4546 self.codegen_comment(wrapper) 4547 args = [*self.codegen_args(), *self.codegen_kwargs(skip_out=True)] 4548 wrapper.generate_extern_kernel_out( 4549 self.get_kernel_name(), 4550 self.codegen_reference(), 4551 self.output_view.codegen_reference() if self.output_view else None, 4552 args, 4553 ) 4554 4555 def __init__( 4556 self, 4557 layout, 4558 inputs, 4559 constant_args=(), 4560 kwargs=None, 4561 output_view=None, 4562 python_kernel_name=None, 4563 cpp_kernel_name=None, 4564 ordered_kwargs_for_cpp_kernel=(), 4565 op_overload=None, 4566 ): 4567 super().__init__( 4568 None, 4569 layout, 4570 self.unwrap_storage(inputs), 4571 constant_args, 4572 kwargs or {}, 4573 None, 4574 python_kernel_name, 4575 cpp_kernel_name, 4576 ordered_kwargs_for_cpp_kernel, 4577 op_overload, 4578 ) 4579 self.name = V.graph.register_buffer(self) 4580 4581 def should_allocate(self): 4582 return True 4583 4584 4585class RandomSeeds(ExternKernelOut): 4586 def __init__(self, count: int, device: torch.device): 4587 limits = torch.iinfo(torch.int64) 4588 super().__init__( 4589 layout=FixedLayout( 4590 device=device, 4591 dtype=torch.int64, 4592 size=[count], 4593 ), 4594 inputs=[], 4595 constant_args=[limits.min, limits.max, [count]], 4596 python_kernel_name="aten.randint.low_out", 4597 # FIXME: Ideally we should only use at::_ops::randint_low_out::call here, 4598 # but the signature is different from is at::randint_out. Again, 4599 # we can simplify the code when only keeping an ABI-compatible version. 
4600 cpp_kernel_name="at::_ops::randint_low_out::call" 4601 if config.abi_compatible 4602 else "at::randint_out", 4603 op_overload=aten.randint.low_out, 4604 ) 4605 4606 4607class ExternKernelAlloc(ExternKernel): 4608 def codegen(self, wrapper): 4609 self.codegen_comment(wrapper) 4610 args = [*self.codegen_args(), *self.codegen_kwargs()] 4611 V.graph.wrapper_code.generate_extern_kernel_alloc(self, args) 4612 if isinstance(self.layout, Layout): 4613 self.codegen_size_asserts(wrapper) 4614 4615 def __init__( 4616 self, 4617 layout, 4618 inputs, 4619 constant_args=(), 4620 kwargs=None, 4621 python_kernel_name=None, 4622 cpp_kernel_name=None, 4623 ordered_kwargs_for_cpp_kernel=(), 4624 op_overload=None, 4625 ): 4626 super().__init__( 4627 None, 4628 layout, 4629 self.unwrap_storage(inputs), 4630 constant_args, 4631 kwargs or {}, 4632 None, 4633 python_kernel_name, 4634 cpp_kernel_name, 4635 ordered_kwargs_for_cpp_kernel, 4636 op_overload, 4637 ) 4638 self.name = V.graph.register_buffer(self) 4639 4640 def should_allocate(self): 4641 return False 4642 4643 def apply_constraint(self): 4644 raise NotImplementedError 4645 4646 4647class UserDefinedTritonKernel(ExternKernel): 4648 def get_kernel_and_configs(self): 4649 from triton.runtime.autotuner import Autotuner 4650 4651 from torch._higher_order_ops.triton_kernel_wrap import kernel_side_table 4652 4653 kernel = kernel_side_table.get_kernel(self.kernel_idx) 4654 configs = [] 4655 if isinstance(kernel, Autotuner): 4656 configs = kernel.configs 4657 kernel = kernel.fn 4658 return kernel, configs 4659 4660 def codegen(self, wrapper): 4661 kernel, configs = self.get_kernel_and_configs() 4662 4663 # Definition of kernel 4664 new_name, triton_meta = wrapper.define_user_defined_triton_kernel( 4665 kernel, configs, self.kwargs 4666 ) 4667 4668 args = self.codegen_kwargs() 4669 arg_types = [] 4670 if V.graph.cpp_wrapper: 4671 # in C++ wrapper, we don't pass constexpr args, as they don't 4672 # get added as parameters to the PTX code compiled from the 4673 # user-defined Triton kernel (only non-constexpr args do) 4674 args = [arg for i, arg in enumerate(args) if i not in kernel.constexprs] 4675 # cpp wrapper needs arg type info for codegen 4676 for arg_name in self.ordered_kwargs_for_cpp_kernel: 4677 val = self.get_kwargs_value(arg_name) 4678 arg_types.append( 4679 val.get_dtype() if hasattr(val, "get_dtype") else type(val) 4680 ) 4681 arg_types = [ 4682 t for i, t in enumerate(arg_types) if i not in kernel.constexprs 4683 ] 4684 4685 # Call to kernel 4686 self.codegen_comment(wrapper) 4687 wrapper.generate_user_defined_triton_kernel( 4688 new_name, self.grid, configs, args, triton_meta, arg_types 4689 ) 4690 4691 def should_allocate(self): 4692 return False 4693 4694 def has_side_effects(self): 4695 # UserDefinedTritonKernel does not return anything, but rather 4696 # modifies input in place, do not let it get DCEd 4697 return True 4698 4699 def get_unbacked_symbol_uses(self) -> Set[sympy.Symbol]: 4700 # add unbacked symbols used in the grid to the ones used 4701 # in the kwargs (the latter is generated by ExternKernel) 4702 return super().get_unbacked_symbol_uses() | free_unbacked_symbols(self.grid) 4703 4704 def get_unbacked_symbol_defs(self) -> Set[sympy.Symbol]: 4705 return set() 4706 4707 def get_mutation_names(self): 4708 # NB: Inductor only allows a node to mutate 0 or 1 buffers. 4709 # To get around that, we create MutationOutputs which marks their 4710 # assigned input as mutable, thus, adhering to Inductor's constraint. 
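        # Rough sketch of that pattern (see mark_node_as_mutating below): for each
        # mutated arg we construct MutationOutput(arg.get_layout(), arg, self); that
        # auxiliary node reports the mutated buffer through its own
        # get_mutation_names(), so this kernel itself reports none.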
4711 return [] 4712 4713 def __init__(self, *, kernel_idx, grid, kernel_args): 4714 inputs = [] 4715 kwargs = dict() 4716 constant_args = [] 4717 for k, v in kernel_args.items(): 4718 if isinstance(v, TensorBox): 4719 t = InputsKernel.unwrap_storage_for_input(self.realize_input(v)) 4720 inputs.append(t) 4721 kwargs[k] = t 4722 else: 4723 constant_args.append(v) 4724 kwargs[k] = v 4725 4726 assert len(inputs) != 0 4727 device = inputs[0].get_device() 4728 4729 super().__init__( 4730 None, 4731 NoneLayout(device), # type: ignore[arg-type] 4732 inputs, 4733 tuple(constant_args), 4734 kwargs, 4735 ) 4736 self.name = V.graph.register_buffer(self) 4737 self.kernel_idx = kernel_idx 4738 self.grid = grid 4739 4740 kernel, configs = self.get_kernel_and_configs() 4741 # If we are autotuning, not all arguments will be passed 4742 self.ordered_kwargs_for_cpp_kernel = [ 4743 arg for arg in kernel.arg_names if arg in kernel_args 4744 ] 4745 4746 from torch._higher_order_ops.triton_kernel_wrap import identify_mutated_tensors 4747 4748 autotuned_kwargs = configs[0].kwargs if len(configs) > 0 else {} 4749 self.mutable_args = [ 4750 kernel_args[key] 4751 for key in identify_mutated_tensors( 4752 kernel, {**kernel_args, **autotuned_kwargs} 4753 ) 4754 ] 4755 mark_node_as_mutating(self, *self.mutable_args) 4756 4757 def get_inputs_that_alias_output(self): 4758 return [i.get_name() for i in self.mutable_args] 4759 4760 4761def mark_node_as_mutating(cur_buffer, *mutated_nodes: IRNode): 4762 """ 4763 Allows ops in mutated_nodes to be marked as being mutated as well as 4764 indicates to the scheduler that these ops depend on cur_buffer. 4765 4766 NB: Use this instead of directly constructing MutationOutput 4767 """ 4768 for node in mutated_nodes: 4769 assert isinstance( 4770 node, IRNode 4771 ), f"{node} node is type {type(node)} and is not an IRNode" 4772 V.graph.mark_buffer_mutated(node.get_name()) 4773 MutationOutput(node.get_layout(), node, cur_buffer) 4774 4775 4776class MutationOutput(ExternKernel): 4777 def get_mutation_names(self): 4778 return [self.inputs[0].get_name()] 4779 4780 def __init__(self, layout, mutated_node, node_doing_mutating): 4781 # NB: Do not directly construct this - use `mark_node_as_mutating` 4782 super().__init__(None, layout, [mutated_node, node_doing_mutating], ()) 4783 self.node_doing_mutating = node_doing_mutating 4784 self.name = V.graph.register_buffer(self) 4785 4786 def should_allocate(self): 4787 return False 4788 4789 def is_no_op(self): 4790 return True 4791 4792 def has_side_effects(self): 4793 return True 4794 4795 def get_inputs_that_alias_output(self): 4796 return [self.inputs[0].get_name()] 4797 4798 4799class InplaceBernoulliFallback(ExternKernel): 4800 """ 4801 This needs to be a custom class to handle mutation properly 4802 """ 4803 4804 def codegen(self, wrapper): 4805 (x,) = (t.codegen_reference() for t in self.inputs) 4806 4807 if V.graph.cpp_wrapper and config.abi_compatible: 4808 # Inductor doesn't really support aten Generator, so the Generator kwarg is always NULL here, 4809 # which needs to be explicitly generated for cpp wrapper 4810 wrapper.writeline( 4811 f"{self.get_kernel_name()}({x}, {', '.join(map(repr, self.constant_args))}, NULL){wrapper.ending}" 4812 ) 4813 else: 4814 wrapper.writeline( 4815 f"{self.get_kernel_name()}({x}, {', '.join(map(repr, self.constant_args))}){wrapper.ending}" 4816 ) 4817 4818 def should_allocate(self): 4819 return False 4820 4821 def get_mutation_names(self): 4822 return [self.inputs[0].get_name()] 4823 4824 def 
get_unbacked_symbol_defs(self) -> Set[sympy.Symbol]: 4825 return set() 4826 4827 def __init__(self, op_overload, x, *constant_args): 4828 super().__init__( 4829 None, 4830 NoneLayout(x.get_device()), # type: ignore[arg-type] 4831 self.unwrap_storage([x]), 4832 constant_args, 4833 op_overload=op_overload, 4834 ) 4835 self.name = V.graph.register_buffer(self) 4836 self.python_kernel_name = "aten.bernoulli_" 4837 if not config.abi_compatible: 4838 # TODO: this should be simplified once we switch to ABI-compatible only 4839 self.cpp_kernel_name = "at::native::bernoulli_" 4840 mark_node_as_mutating(self, x) 4841 4842 4843# Used to deal with torch.complex types 4844class InplaceCopyFallback(ExternKernel): 4845 """ 4846 This needs to be a custom class to handle mutation properly 4847 """ 4848 4849 def codegen(self, wrapper): 4850 (dst, src, non_blocking) = self.codegen_args() 4851 wrapper.writeline( 4852 f"{self.get_kernel_name()}({dst}, {src}, {non_blocking}){wrapper.ending}" 4853 ) 4854 4855 def should_allocate(self): 4856 return False 4857 4858 def get_mutation_names(self): 4859 return [self.inputs[0].get_name()] 4860 4861 def get_unbacked_symbol_defs(self) -> Set[sympy.Symbol]: 4862 return set() 4863 4864 def __init__( 4865 self, 4866 layout, 4867 inputs, 4868 constant_args, 4869 ): 4870 super().__init__( 4871 None, 4872 layout, 4873 inputs, 4874 constant_args, 4875 python_kernel_name="aten.copy_", 4876 cpp_kernel_name=( 4877 "aoti_torch_copy_" if config.abi_compatible else "at::_ops::copy_::call" 4878 ), 4879 ) 4880 self.name = V.graph.register_buffer(self) 4881 4882 @classmethod 4883 def create(cls, dst, src, non_blocking: bool = False): 4884 inputs = [cls.realize_input(t) for t in [dst, src]] 4885 constant_args = (non_blocking,) 4886 result = InplaceCopyFallback( 4887 NoneLayout(dst.get_device()), # type: ignore[arg-type] 4888 inputs, 4889 constant_args, 4890 ) 4891 mark_node_as_mutating(result, dst) 4892 return result 4893 4894 4895class MutatingFirstArgExternKernel(ExternKernel): 4896 """ 4897 This needs to be a custom class to handle mutation properly 4898 """ 4899 4900 def codegen(self, wrapper): 4901 argrefs = [ 4902 *(t.codegen_reference() for t in self.inputs), 4903 *map(repr, self.constant_args), 4904 ] 4905 wrapper.writeline( 4906 f"{self.get_kernel_name()}({', '.join(argrefs)}){wrapper.ending}" 4907 ) 4908 4909 def should_allocate(self): 4910 return False 4911 4912 def get_mutation_names(self): 4913 return [self.inputs[0].get_name()] 4914 4915 def get_unbacked_symbol_defs(self) -> Set[sympy.Symbol]: 4916 return set() 4917 4918 def has_side_effects(self): 4919 return True 4920 4921 4922class ResizeStorageBytes(MutatingFirstArgExternKernel): 4923 def __init__(self, variable, new_size): 4924 assert isinstance(new_size, int), "TODO: dynamic shapes" 4925 super().__init__( 4926 None, 4927 NoneLayout(variable.get_device()), # type: ignore[arg-type] 4928 self.unwrap_storage([variable]), 4929 constant_args=(new_size,), 4930 ) 4931 V.graph.mark_buffer_mutated(variable.get_name()) 4932 self.name = V.graph.register_buffer(self) 4933 self.python_kernel_name = "inductor_ops.resize_storage_bytes_" 4934 self.cpp_kernel_name = "torch::inductor::resize_storage_bytes_" 4935 V.graph.never_reuse_buffers.add(variable.data.get_name()) 4936 mark_node_as_mutating(self, variable) 4937 4938 4939class SetSourceTensorKernel(ExternKernelAlloc): 4940 def __init__(self, self_tensor, storage_tensor): 4941 self_tensor.freeze_layout() 4942 super().__init__( 4943 self_tensor.get_layout(), 4944 [self_tensor, 
storage_tensor],
4945            python_kernel_name="torch.ops.aten.set_.source_Tensor",
4946        )
4947        V.graph.never_reuse_buffers.add(self_tensor.data.get_name())
4948        V.graph.never_reuse_buffers.add(storage_tensor.get_name())
4949        V.graph.never_reuse_buffers.add(self.get_name())
4950        mark_node_as_mutating(self, self_tensor, storage_tensor)
4951
4952    def get_inputs_that_alias_output(self):
4953        return [self.inputs[0].get_name(), self.inputs[1].get_name()]
4954
4955    def get_mutation_names(self):
4956        return [self.inputs[1].get_name()]
4957
4958    def has_side_effects(self):
4959        return True
4960
4961
4962class ScatterFallback(ExternKernel):
4963    """
4964    This needs to be a custom class to handle mutation properly.
4965    This class handles both aten.scatter_ and aten.scatter_reduce_.
4966    It also handles the case where `src` is a scalar.
4967    """
4968
4969    def codegen(self, wrapper):
4970        reduce = self.kwargs["reduce"]
4971        if V.graph.cpp_wrapper:
4972            # Follow aten/src/ATen/native/ReductionType.h:get_operator_enum
4973            get_operator_enum = {"add": "sum", "multiply": "prod"}
4974            if reduce in get_operator_enum:
4975                reduce = get_operator_enum[reduce]
4976
4977        if self.src_is_tensor:
4978            (x, index, src) = (t.codegen_reference() for t in self.inputs)
4979        else:
4980            (x, index) = (t.codegen_reference() for t in self.inputs)
4981            src = self.constant_args[1]
4982        wrapper.generate_scatter_fallback(
4983            x,
4984            [x, self.constant_args[0], index, src],
4985            self.cpp_kernel_name,
4986            self.python_kernel_name,
4987            self.src_is_tensor,
4988            reduce,
4989            self.codegen_kwargs(),
4990        )
4991
4992    def should_allocate(self):
4993        return False
4994
4995    def get_mutation_names(self):
4996        return [self.inputs[0].get_name()]
4997
4998    def get_unbacked_symbol_defs(self) -> Set[sympy.Symbol]:
4999        return set()
5000
5001    def __init__(
5002        self,
5003        op_overload,
5004        x,
5005        dim: int,
5006        index,
5007        src,
5008        *,
5009        reduce: Optional[str] = None,
5010        include_self: bool = True,
5011    ):
5012        self.src_is_tensor = isinstance(src, TensorBox)
5013
5014        constant_args: Tuple[Any, ...]
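        # Descriptive note: when `src` is a TensorBox it is realized and passed as a
        # kernel input with constant_args == (dim,); when it is a scalar it travels
        # through constant_args as (dim, src), which codegen() reads back via
        # self.constant_args[1].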
5015 if self.src_is_tensor: 5016 tensors = [self.realize_input(t) for t in [x, index, src]] 5017 constant_args = (dim,) 5018 else: 5019 tensors = [self.realize_input(t) for t in [x, index]] 5020 constant_args = (dim, src) 5021 5022 super().__init__( 5023 None, 5024 NoneLayout(x.get_device()), # type: ignore[arg-type] 5025 self.unwrap_storage(tensors), 5026 constant_args, 5027 {"reduce": reduce, "include_self": include_self}, 5028 python_kernel_name=str(op_overload), 5029 ordered_kwargs_for_cpp_kernel=["reduce", "include_self"], 5030 op_overload=op_overload, 5031 ) 5032 self.cpp_kernel_name = get_aten_cpp_kernel_name(op_overload) 5033 self.name = V.graph.register_buffer(self) 5034 mark_node_as_mutating(self, x) 5035 5036 5037class IndexPutFallback(ExternKernel): 5038 """ 5039 This needs to be a custom class to handle mutation and indices properly 5040 """ 5041 5042 def codegen(self, wrapper): 5043 (x, values, *valid_indices) = (t.codegen_reference() for t in self.inputs) 5044 indices = [] 5045 iter_valid_indices = iter(valid_indices) 5046 for i, _ in enumerate(self.indices): 5047 if self.indices[i] is not None: 5048 indices.append(next(iter_valid_indices)) 5049 else: 5050 indices.append(V.graph.wrapper_code.none_str) 5051 5052 wrapper.generate_index_put_fallback( 5053 self.get_kernel_name(), x, indices, values, *self.codegen_const_args() 5054 ) 5055 5056 def should_allocate(self): 5057 return False 5058 5059 def get_mutation_names(self): 5060 return [self.inputs[0].get_name()] 5061 5062 def get_unbacked_symbol_defs(self) -> Set[sympy.Symbol]: 5063 return set() 5064 5065 def __init__(self, op_overload, x, indices, values, accumulate): 5066 self.indices = indices 5067 valid_indices = [i for i in indices if i is not None] 5068 tensors = [self.realize_input(x) for x in [x, values, *valid_indices]] 5069 cpp_kernel_name = ( 5070 "aoti_torch_index_put_out" if config.abi_compatible else "at::index_put_out" 5071 ) 5072 super().__init__( 5073 None, 5074 NoneLayout(x.get_device()), # type: ignore[arg-type] 5075 self.unwrap_storage(tensors), 5076 (accumulate,), 5077 python_kernel_name="aten.index_put_", 5078 cpp_kernel_name=cpp_kernel_name, 5079 op_overload=op_overload, 5080 ) 5081 self.name = V.graph.register_buffer(self) 5082 mark_node_as_mutating(self, x) 5083 5084 5085class DeviceCopy(ExternKernelOut): 5086 @classmethod 5087 def create(cls, x, device): 5088 if ( 5089 not x.is_extern() 5090 and all( 5091 (r.name in V.graph.constants and isinstance(r, dependencies.MemoryDep)) 5092 for r in x.get_reads() 5093 ) 5094 and not config.aot_inductor.use_runtime_constant_folding 5095 ): 5096 return x.constant_to_device(device) 5097 5098 V.graph.add_device_info(device) 5099 V.graph.add_device_info(x.get_device()) 5100 5101 developer_warning("DeviceCopy in input program") 5102 return DeviceCopy( 5103 FlexibleLayout( 5104 device=device, 5105 dtype=x.get_dtype(), 5106 size=x.get_size(), 5107 ), 5108 [cls.realize_input(x)], 5109 ) 5110 5111 def codegen(self, wrapper): 5112 args = self.codegen_args() 5113 assert len(args) == 1 5114 if self.output_view: 5115 wrapper.codegen_device_copy(args[0], self.output_view.codegen_reference()) 5116 else: 5117 wrapper.codegen_device_copy(args[0], self.codegen_reference()) 5118 5119 5120class DynamicScalar(ExternKernel): 5121 """ 5122 The result of a call to aten._local_scalar_dense. 
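    It defines the unbacked symbol `sym` (see get_unbacked_symbol_defs) and is
    code-generated through wrapper.codegen_dynamic_scalar, which is expected to
    bind that symbol to the scalar value extracted from `data` at runtime.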
5123 """ 5124 5125 def get_reads(self): 5126 return () 5127 5128 def should_allocate(self): 5129 return False 5130 5131 def __init__(self, sym, keypath, data): 5132 data.realize() 5133 super().__init__(None, NoneLayout(torch.device("cpu")), self.unwrap_storage([data])) # type: ignore[arg-type] 5134 self.sym = sym 5135 self.keypath = keypath 5136 5137 def get_unbacked_symbol_defs(self) -> Set[sympy.Symbol]: 5138 return {self.sym} 5139 5140 def codegen(self, wrapper): 5141 wrapper.codegen_dynamic_scalar(self) 5142 5143 5144class AssertScalar(ExternKernel): 5145 """ 5146 The result of a call to aten._assert_scalar 5147 """ 5148 5149 def get_reads(self): 5150 return () 5151 5152 def should_allocate(self): 5153 return False 5154 5155 def __init__(self, scalar, msg): 5156 super().__init__( 5157 # Buffer(name, layotu) 5158 None, 5159 NoneLayout(torch.device("cpu")), # type: ignore[arg-type] 5160 # InputsKernel(inputs) 5161 [], 5162 ) # type: ignore[arg-type] 5163 self.scalar = scalar 5164 self.msg = msg 5165 5166 def has_side_effects(self): 5167 return True 5168 5169 def get_unbacked_symbol_uses(self): 5170 return free_unbacked_symbols(self.scalar) 5171 5172 def codegen(self, wrapper): 5173 if V.graph.cpp_wrapper: 5174 pass 5175 else: 5176 # NB: It is EXTREMELY important not to simplify the scalar under 5177 # assertion here, because simplify is done with respect to 5178 # runtime asserts. So if you have "u0 == 0" in the runtime 5179 # asserts, if you subsequently try to simplify(u0 == 0), you will 5180 # get True (because we've already runtime assert'ed that it's 5181 # true). But we're code generating the actual runtime assert 5182 # here!! 5183 wrapper.writeline( 5184 f"if not {V.graph.wrapper_code.codegen_python_sizevar(self.scalar, simplify=False)}:" 5185 ) 5186 wrapper.writeline(f" raise RuntimeError({repr(self.msg)})") 5187 # No one should ever use this buffer, but for uniformity 5188 # define the variable and assign it None 5189 wrapper.writeline(f"{self.get_name()} = None") 5190 5191 5192@dataclasses.dataclass 5193class ExternKernelNode: 5194 name: str 5195 node: export_schema.Node 5196 5197 5198has_c_shim = { 5199 aten._embedding_bag.default, 5200 aten._fft_c2c.default, 5201 aten._scaled_dot_product_efficient_attention.default, 5202 aten._scaled_dot_product_flash_attention.default, 5203 aten._scaled_mm.default, 5204 aten.addmm.out, 5205 aten.bmm.out, 5206 aten.copy_.default, 5207 aten.mm.out, 5208 aten.repeat_interleave.Tensor, 5209 aten.nonzero.default, 5210 aten.view.dtype, 5211 aten.view_as_real.default, 5212} 5213 5214 5215class FallbackKernel(ExternKernelAlloc): 5216 def __init__( 5217 self, 5218 layout, 5219 kernel, 5220 tensor_args, 5221 nontensor_args, 5222 unflatten_args, 5223 kwargs=None, 5224 *, 5225 unbacked_bindings=None, 5226 ): 5227 if ( 5228 kernel == aten.mul.Tensor 5229 and len(tensor_args) == 1 5230 and len(nontensor_args) == 1 5231 ): 5232 # When aten.mul.Tensor's second arg is constant, cpp wrapper expects 5233 # to call mul_Scalar. A more proper fix is to do it in decomposition. 5234 # See https://github.com/pytorch/pytorch/issues/123478 5235 kernel = aten.mul.Scalar 5236 5237 super().__init__( 5238 layout, 5239 tuple(tensor_args), 5240 tuple(nontensor_args), 5241 op_overload=kernel, 5242 ) 5243 5244 # We need output buffers for generating kernel arguments in the 5245 # abi-compatible mode, where we retrieve outputs by pass each individual 5246 # output through the abi-compatible interface. 
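        # Note (added for clarity): self.outputs starts out empty and is later filled
        # in by FallbackKernel.create() via generate_output(), mirroring the structure
        # (tensor, list/tuple, or dict) of the op's example output.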
5247 self.outputs: Sequence[Any] = [] 5248 self.use_runtime_dispatch = False 5249 self.unbacked_bindings = unbacked_bindings 5250 5251 assert isinstance( 5252 kernel, 5253 ( 5254 torch._ops.OpOverload, 5255 torch._ops.HigherOrderOperator, 5256 ), 5257 ), f"Fails to create FallbackKernel for {kernel}: {type(kernel)} not supported" 5258 self.op_overload = kernel 5259 self.unflatten_args = unflatten_args 5260 self.kwargs = {} if kwargs is None else kwargs 5261 V.graph.warn_fallback(self.python_kernel_name) 5262 5263 # args that are aliased 5264 self.alias_names: List[str] = [] 5265 # args that are mutated AND returned from the op 5266 self.mutation_names: List[str] = [] 5267 5268 if isinstance(self.op_overload, torch._ops.HigherOrderOperator): 5269 # We assume here that HOPs with FallbackKernel are functional. 5270 # This may not always be true! HOPs must individually opt-in to 5271 # FallbackKernel, so please check this if you opt-in. 5272 return 5273 5274 if "_c10d_functional" in self.op_overload.name(): 5275 # _c10d_functional kernels are lowered into _CollectiveKernel which 5276 # derives from FallbackKernel for the cpp codegen. The kernels 5277 # don't pass the can_auto_functionalize check, but their mutation 5278 # is handled properly by _CollectiveKernel. 5279 return 5280 5281 schema = self.op_overload._schema 5282 5283 # NOTE: [FallbackKernel supported operators] 5284 # We only support three types of operators: 5285 # - functional ops 5286 # - view ops 5287 # - inplace aten ops 5288 # - mutating ops that are auto-functionalizable. That is, 5289 # the operator may mutate any number of inputs, but its outputs 5290 # may not alias any of the inputs. 5291 # 5292 # The unsupported cases usually do not show up here (because 5293 # AOTAutograd functionalized them away); the only way for an in-place 5294 # op to show up here is if a lowering or pass introduced it. 5295 if torch._library.utils.mutates_and_returns_first_arg(self.op_overload): 5296 self.mutation_names.append(tensor_args[0].get_name()) 5297 return 5298 5299 if schema.is_mutable and not can_auto_functionalize(kernel): 5300 raise NotImplementedError( 5301 f"NYI: Can't generate FallbackKernel for {kernel}" 5302 ) 5303 5304 schema_args = schema.arguments 5305 args, kwargs = self.unflatten_args(self.inputs, self.constant_args) 5306 5307 def handle_aliasing_and_mutation(info, arg): 5308 # Assertions to make sure we didn't mismatch args 5309 if isinstance(info.type, torch.ListType): 5310 assert isinstance(arg, (list, tuple)) 5311 is_optional_tensor = isinstance( 5312 info.type, torch.OptionalType 5313 ) and isinstance(info.type.getElementType(), torch.TensorType) 5314 if is_optional_tensor or isinstance(info.type, torch.TensorType): 5315 # PyTorch also accepts None and scalar types for args marked as "Tensor". 5316 # We're not going to check all of them here. 5317 assert not isinstance(arg, (tuple, list)) 5318 5319 if arg is None: 5320 return 5321 if info.alias_info is None: 5322 return 5323 # can_auto_functionalize already filters out mutable List[Tensor]. 5324 # We can support this in the future, but this is very uncommon. 
5325 assert isinstance(info.type, torch.TensorType) or is_optional_tensor 5326 self.alias_names.append(arg.get_name()) 5327 if info.alias_info.is_write: 5328 mark_node_as_mutating(self, arg) 5329 5330 for info, arg in torch._library.utils.zip_schema(schema, args, kwargs): 5331 handle_aliasing_and_mutation(info, arg) 5332 5333 def codegen_unbacked_symbol_defs(self, wrapper): 5334 if not hasattr(self, "unbacked_bindings"): 5335 return 5336 5337 unbacked_bindings = resolve_unbacked_bindings( 5338 V.graph.sizevars.shape_env, self.unbacked_bindings 5339 ) 5340 5341 if not unbacked_bindings: 5342 return 5343 5344 for s, keypath in unbacked_bindings.items(): 5345 5346 def go(expr, keypath): 5347 if keypath == (): 5348 return expr 5349 5350 if ( 5351 len(keypath) >= 2 5352 and isinstance(keypath[0], CallMethodKey) 5353 and isinstance(keypath[1], pytree.SequenceKey) 5354 ): 5355 return go( 5356 f"{expr}.{keypath[0].name}({keypath[1].idx})", keypath[2:] 5357 ) 5358 elif isinstance(keypath[0], CallMethodKey): 5359 return go(f"{expr}.{keypath[0].name}()", keypath[1:]) 5360 elif isinstance(keypath[0], pytree.SequenceKey): 5361 return go(f"{expr}[{keypath[0].idx}]", keypath[1:]) 5362 elif isinstance(keypath[0], DivideByKey): 5363 # TODO: need to assert divisibility 5364 # TODO: this is invalid C++ codegen 5365 return go(f"{expr}.__floordiv__({keypath[0].divisor})", keypath[1:]) 5366 else: 5367 raise AssertionError(f"unrecognized keypath {keypath}") 5368 5369 def go_outer(): 5370 if V.graph.cpp_wrapper and config.abi_compatible: 5371 # Special handling for the top level buffer access, 5372 # because self.get_name() is actually never bound; the 5373 # individual output arguments are bound by 5374 # generate_c_shim_fallback_kernel 5375 if len(self.outputs) == 1: 5376 return go(self.outputs[0].get_name(), keypath) 5377 else: 5378 assert isinstance(keypath[0], pytree.SequenceKey) 5379 return go(self.outputs[keypath[0].idx].get_name(), keypath[1:]) 5380 else: 5381 return go(self.get_name(), keypath) 5382 5383 wrapper.writeline( 5384 f"{wrapper.codegen_unbacked_symbol_decl(s)} = {go_outer()}{wrapper.ending}" 5385 ) 5386 5387 def get_unbacked_symbol_defs(self) -> Set[sympy.Symbol]: 5388 if unbacked_bindings := getattr(self, "unbacked_bindings", None): 5389 return resolve_unbacked_bindings( 5390 V.graph.sizevars.shape_env, unbacked_bindings 5391 ).keys() 5392 else: 5393 return set() 5394 5395 def set_cpp_kernel(self, kernel): 5396 from .codegen.wrapper import get_cpp_op_schema 5397 5398 assert ( 5399 not kernel._schema.is_mutable 5400 ), f"mutable {kernel.__name__} is not supported with cpp_wrapper" 5401 5402 # These checks are here because ops that return aliasing tensors will 5403 # return type Tensor& instead of Tensor, but codegen will always write 5404 # type Tensor on the LHS. 
5405 def is_not_write(arg): 5406 return arg.alias_info is None or not arg.alias_info.is_write 5407 5408 assert all( 5409 is_not_write(x) for x in kernel._schema.arguments 5410 ), f"{kernel.__name__} with alias_info arguments is not supported with cpp_wrapper" 5411 assert all( 5412 is_not_write(x) for x in kernel._schema.returns 5413 ), f"{kernel.__name__} with alias_info returns is not supported with cpp_wrapper" 5414 5415 self.cpp_kernel_name = kernel._schema.name 5416 self.cpp_kernel_overload_name = kernel._schema.overload_name 5417 self.cpp_kernel_key = f"{self.cpp_kernel_name.replace('::', '_')}_{self.cpp_kernel_overload_name}" # type: ignore[union-attr] 5418 5419 self.cpp_op_schema = get_cpp_op_schema(kernel) 5420 5421 def codegen_args(self): 5422 @dataclasses.dataclass 5423 class Shim: 5424 ref: Any 5425 5426 def __repr__(self): 5427 return self.ref 5428 5429 tensor_args = [Shim(x.codegen_reference()) for x in self.inputs] 5430 args, kwargs = self.unflatten_args(tensor_args, self.constant_args) 5431 if V.graph.cpp_wrapper and isinstance(self.op_overload, torch._ops.OpOverload): 5432 args = self.fill_non_provided_args(args, kwargs) 5433 args = [ 5434 V.graph.wrapper_code.val_to_arg_str(x, param.real_type) 5435 for param, x in zip(self.op_overload._schema.arguments, args) 5436 ] 5437 else: 5438 args = [V.graph.wrapper_code.val_to_arg_str(x) for x in args] 5439 5440 # let self.codegen_kwargs handle kwargs 5441 self.kwargs.update(kwargs) 5442 return args 5443 5444 @staticmethod 5445 def find_device(tensor_args, example_output): 5446 if tensor_args: 5447 devices = [arg.get_device() for arg in tensor_args if arg.get_device()] 5448 return devices[0] 5449 if isinstance(example_output, torch.Tensor): 5450 return example_output.device 5451 if isinstance(example_output, (list, tuple)): 5452 device_set = {FallbackKernel.find_device(None, x) for x in example_output} 5453 # Remove None 5454 devices = [device for device in device_set if device] 5455 if len(devices) == 1: 5456 return devices[0] 5457 for device in devices: 5458 if is_gpu(device.type): 5459 return device 5460 return devices[0] 5461 return None 5462 5463 def has_side_effects(self): 5464 if isinstance(self.op_overload, torch._ops.HigherOrderOperator): 5465 return False 5466 return get_schema_info(self.op_overload).is_mutable() 5467 5468 def get_inputs_that_alias_output(self): 5469 return self.alias_names 5470 5471 def get_mutation_names(self): 5472 assert len(self.mutation_names) <= 1 5473 return self.mutation_names 5474 5475 # ProxyExecutor Design Note 5476 # We export the ExternFallbackNodes (for custom ops) into a serialized file 5477 # and run it with a host side proxy executor to address the ABI problem 5478 # This is currently only implemented for fbcode. Eventually, we will also make this work for OSS. 
5479 # Detailed design doc can be found at 5480 # https://docs.google.com/document/d/1wC4DOZFaYym2t1Esz0X5yxlLI3RDnSiyRbUus3bkJ64/edit?usp=sharing 5481 def export_extern_kernel_node(self): 5482 assert isinstance(self, FallbackKernel) 5483 args, kwargs = self.unflatten_args(self.inputs, self.constant_args) 5484 args = self.fill_non_provided_args(args, kwargs) 5485 ordered_kwargs = [ 5486 kwargs.get(key, None) for key in self.ordered_kwargs_for_cpp_kernel 5487 ] 5488 if not V.graph.aot_mode: 5489 # No need to serialize in the cpp wrapper JIT mode 5490 return [*args, *ordered_kwargs] 5491 5492 serializer = GraphModuleSerializer(None, None) # type: ignore[arg-type] 5493 named_arguments = serializer.serialize_inputs(self.op_overload, args, kwargs) # type: ignore[arg-type] 5494 5495 # serialize_outputs 5496 def handle_single_output(return_type, output): 5497 if isinstance(return_type, torch.TensorType): 5498 # For single Tensor 5499 out = output 5500 if isinstance(output, (list, tuple)): 5501 assert len(output) == 1 5502 out = output[0] 5503 return export_schema.Argument.create( 5504 as_tensor=export_schema.TensorArgument(name=out.get_name()) 5505 ) 5506 elif isinstance(return_type, torch.ListType) and isinstance( 5507 return_type.getElementType(), torch.TensorType 5508 ): 5509 # For single TensorList 5510 return export_schema.Argument.create( 5511 as_tensors=[ 5512 export_schema.TensorArgument(name=out.get_name()) 5513 for out in output 5514 ] 5515 ) 5516 else: 5517 raise RuntimeError(f"Unsupported return type {type(return_type)}") 5518 5519 target = self.op_overload 5520 returns = target._schema.returns # type: ignore[union-attr] 5521 if len(returns) == 1: 5522 return_type = returns[0].real_type 5523 output_arguments = [handle_single_output(return_type, self.outputs)] 5524 else: 5525 # For tuple returns, e.g "-> (Tensor, Tensor)" or "-> (Tesnor, Tensor[])" 5526 assert isinstance(self.outputs, tuple) 5527 assert len(returns) == len(self.outputs) 5528 output_arguments = [ 5529 handle_single_output(return_schema.real_type, output) 5530 for return_schema, output in zip(returns, self.outputs) 5531 ] 5532 5533 node = ExternKernelNode( 5534 name=self.get_name(), 5535 node=export_schema.Node( 5536 target=self.op_overload.name(), # type: ignore[union-attr] 5537 inputs=named_arguments, 5538 outputs=output_arguments, 5539 metadata={}, 5540 ), 5541 ) 5542 5543 V.graph.extern_kernel_nodes.append(node) 5544 5545 return [*args, *ordered_kwargs] 5546 5547 def codegen(self, wrapper): 5548 kernel = self.op_overload 5549 if kernel.namespace == "aten": # type: ignore[union-attr] 5550 # Aten Fallback Ops 5551 assert isinstance(kernel, torch._ops.OpOverload) 5552 if V.graph.cpp_wrapper: 5553 if ( 5554 config.is_fbcode() 5555 and kernel not in has_c_shim 5556 # C shim v2 is torchgen-ed, which should cover all aten ops. 5557 # If you do hit a missed op, please update gen_aoti_c_shim.py. 
5558 and config.c_shim_version == "1" 5559 ): 5560 log.warning( 5561 "%s is missing a c-shim implementation, using proxy executor as fallback", 5562 kernel, 5563 ) 5564 self.use_runtime_dispatch = True 5565 self.set_cpp_kernel(kernel) 5566 else: 5567 self.python_kernel_name = str(kernel) 5568 elif kernel.namespace == "_quantized": # type: ignore[union-attr] 5569 # Internal Quantized Fallback Ops 5570 assert isinstance(kernel, torch._ops.OpOverload) 5571 if V.graph.cpp_wrapper: 5572 self.set_cpp_kernel(kernel) 5573 if not config.abi_compatible: 5574 self.use_runtime_dispatch = True 5575 else: 5576 self.python_kernel_name = str(kernel) 5577 elif isinstance(kernel, torch._ops.HigherOrderOperator): 5578 self.python_kernel_name = f"torch.ops.higher_order.{kernel.__name__}" 5579 else: 5580 # For non-aten OpOverload, i.e. custom ops 5581 self.python_kernel_name = f"{kernel.__module__.replace('._ops.', '.ops.')}.{kernel.__name__}" # type: ignore[union-attr] 5582 if V.graph.cpp_wrapper: 5583 self.use_runtime_dispatch = True 5584 self.set_cpp_kernel(kernel) 5585 5586 if self.use_runtime_dispatch: 5587 self.codegen_comment(wrapper) 5588 5589 exported_args = None 5590 args = None 5591 if config.abi_compatible: 5592 exported_args = self.export_extern_kernel_node() 5593 else: 5594 args = [*self.codegen_args(), *self.codegen_kwargs()] 5595 5596 wrapper.generate_extern_kernel_alloc_and_find_schema_if_needed( 5597 self.get_name(), 5598 self.python_kernel_name, 5599 self.cpp_kernel_name, 5600 args, 5601 self.cpp_op_schema, 5602 self.cpp_kernel_key, 5603 self.cpp_kernel_overload_name, 5604 self.op_overload, 5605 exported_args, 5606 self.outputs, 5607 ) 5608 else: 5609 self.codegen_comment(wrapper) 5610 args = [*self.codegen_args(), *self.codegen_kwargs()] 5611 V.graph.wrapper_code.generate_fallback_kernel(self, args) 5612 if isinstance(self.layout, Layout): 5613 self.codegen_size_asserts(wrapper) 5614 5615 self.codegen_unbacked_symbol_defs(wrapper) 5616 5617 @staticmethod 5618 def tensor_to_layout(output: torch.Tensor): 5619 return FixedLayout( 5620 output.device, 5621 output.dtype, 5622 convert_shape_to_inductor(output.size()), 5623 convert_shape_to_inductor(output.stride()), 5624 ) 5625 5626 @classmethod 5627 def create(cls, kernel, *args, **kwargs): 5628 fake_incorrect_kernels = (aten._fused_moving_avg_obs_fq_helper_functional,) 5629 context = ( 5630 V.graph.fake_mode if kernel not in fake_incorrect_kernels else nullcontext() 5631 ) 5632 with context: 5633 ( 5634 example_output, 5635 tensor_args, 5636 non_tensor_args, 5637 unflatten_args, 5638 unbacked_bindings, 5639 ) = cls.process_kernel(kernel, *args, **kwargs) 5640 5641 device = cls.find_device(tensor_args, example_output) 5642 if example_output is None: 5643 packed = cls( 5644 NoneLayout(device), 5645 kernel, 5646 tensor_args, 5647 non_tensor_args, 5648 unflatten_args, 5649 unbacked_bindings=unbacked_bindings, 5650 ) 5651 5652 else: 5653 assert device, "Not sure where to find device info" 5654 packed = cls( 5655 MultiOutputLayout(device), 5656 kernel, 5657 tensor_args, 5658 non_tensor_args, 5659 unflatten_args, 5660 unbacked_bindings=unbacked_bindings, 5661 ) 5662 5663 def generate_output(output, indices): 5664 if isinstance(output, (list, tuple)): 5665 return type(output)( 5666 generate_output(output[i], indices + [(type(output), i)]) 5667 for i in range(len(output)) 5668 ) 5669 elif isinstance(output, dict): 5670 return { 5671 key: generate_output(val, indices + [(type(output), key)]) 5672 for key, val in output.items() 5673 } 5674 elif 
isinstance(output, torch.Tensor): 5675 return MultiOutput( 5676 cls.tensor_to_layout(output), 5677 packed, 5678 indices, 5679 ) 5680 elif isinstance(output, int): 5681 return output 5682 elif isinstance(output, torch.SymInt): 5683 return output.node.expr 5684 else: 5685 assert ( 5686 output is None 5687 ), f"FallbackKernel output type {type(output)} is not supported" 5688 return None 5689 5690 outputs = generate_output(example_output, []) 5691 if isinstance(outputs, (list, tuple, dict)): 5692 packed.outputs = outputs # type: ignore[assignment] 5693 else: 5694 packed.outputs = [outputs] 5695 return outputs 5696 5697 def apply_constraint(self): 5698 return super().apply_constraint() 5699 5700 5701@dataclasses.dataclass 5702class ComplexView(FallbackKernel): 5703 """View a complex number as two dtyped numbers or vice versa""" 5704 5705 def should_allocate(self): 5706 return False 5707 5708 def get_inputs_that_alias_output(self): 5709 # Signal to codegen that our output buffer isn't safe to reuse 5710 return [self.inputs[0].get_name()] 5711 5712 def __init__( 5713 self, 5714 layout, 5715 kernel, 5716 tensor_args, 5717 nontensor_args, 5718 unflatten_args, 5719 *, 5720 unbacked_bindings=None, 5721 ): 5722 super().__init__( 5723 layout, 5724 kernel, 5725 tensor_args, 5726 nontensor_args, 5727 unflatten_args, 5728 unbacked_bindings=unbacked_bindings, 5729 ) 5730 5731 5732@dataclasses.dataclass 5733class MultiOutputLayout(IRNode): 5734 device: torch.device 5735 5736 5737class MultiOutput(ExternKernel): 5738 # Given an input MultiOutputLayout buffer, indexes out an actual buffer 5739 # from that result. This doesn't actually produce multiple outputs, 5740 # that's MultiOutputLayout! 5741 def codegen_list_tuple_access(self, basename, indices): 5742 if len(indices) > 0: 5743 itype, i = indices[0] 5744 if issubclass(itype, list): 5745 return self.codegen_list_tuple_access(f"{basename}[{i}]", indices[1:]) 5746 elif issubclass(itype, tuple): 5747 # cpp wrapper code needs to use std::get<> to access a tuple 5748 tuple_access = V.graph.wrapper_code.codegen_tuple_access( 5749 basename, self.get_name(), str(i) 5750 ) 5751 return self.codegen_list_tuple_access(tuple_access, indices[1:]) 5752 elif issubclass(itype, dict): 5753 return self.codegen_list_tuple_access(f"{basename}['{i}']", indices[1:]) 5754 else: 5755 raise AssertionError("non supported index type: ", itype) 5756 else: 5757 return basename 5758 5759 def codegen(self, wrapper): 5760 wrapper.codegen_multi_output( 5761 self.get_name(), 5762 self.codegen_list_tuple_access(self.inputs[0].get_name(), self.indices), 5763 ) 5764 5765 def __init__(self, layout, input, indices: List[Tuple[Any, ...]]): 5766 super().__init__(None, layout, [input], ()) 5767 self.name = V.graph.register_buffer(self) 5768 self.indices = indices 5769 5770 def get_unbacked_symbol_uses(self) -> Set[sympy.Symbol]: 5771 return self.inputs[0].get_unbacked_symbol_uses() 5772 5773 def should_allocate(self): 5774 return False 5775 5776 def get_inputs_that_alias_output(self): 5777 return [ 5778 inp.get_name() 5779 for inp in self.inputs 5780 if isinstance(inp, FallbackKernel) 5781 and len(inp.get_inputs_that_alias_output()) > 0 5782 ] 5783 5784 5785def _prepare_convolution_fusion_create( 5786 cls, 5787 x: "TensorBox", 5788 weight: "TensorBox", 5789 bias: "TensorBox", 5790 padding: List[int], 5791 stride: List[int], 5792 dilation: List[int], 5793 groups: int, 5794 transposed: bool = False, 5795 output_padding: Optional[List[int]] = None, 5796): 5797 """ 5798 This function is a helper 
function to prepare inputs, layout and constant args 5799 for convolution post-op fusion's create function, including deciding the output 5800 layout (channels first or channels last), realizing inputs and make them etc. The 5801 function only supports the CPU device since conv post-op fusion kernel is only 5802 supported on CPU right now. 5803 """ 5804 5805 # Port from aten/src/ATen/native/ConvUtils.h: _conv_input_size 5806 def _conv_input_size( 5807 output_size, weight_size, padding, output_padding, stride, dilation, groups 5808 ): 5809 assert len(output_size) == len(weight_size), "Expect input dim == weight dim" 5810 dim = len(output_size) 5811 assert dim > 2, "Expect input dim > 2" 5812 5813 BATCH_DIM = 0 5814 WEIGHT_INPUT_CHANNELS_DIM = 1 5815 input_size = [] 5816 input_size.append(output_size[BATCH_DIM]) 5817 input_size.append(weight_size[WEIGHT_INPUT_CHANNELS_DIM] * groups) 5818 for d in range(2, dim): 5819 kernel = (weight_size[d] - 1) * dilation[d - 2] + 1 5820 input_size_d = ( 5821 (output_size[d] - 1) * stride[d - 2] 5822 - (padding[d - 2] * 2) 5823 + kernel 5824 + output_padding[d - 2] 5825 ) 5826 input_size.append(input_size_d) 5827 return list(map(int, input_size)) 5828 5829 # The size of prepacked_weight is the prepacked weight size of deconv: 5830 # Groups > 1: [g*o, i/g, ...] 5831 # Groups == 1: [o, i, ...] 5832 # Returns original weight size in [i, o, ...] 5833 def _original_deconv_weight_size( 5834 prepacked_weight, 5835 groups, 5836 ): 5837 prepacked_weight_size = prepacked_weight.size() 5838 dim = len(prepacked_weight_size) 5839 assert dim > 2, "Expect weight dim > 2" 5840 if groups > 1: 5841 weight_size = [] 5842 weight_size.append(prepacked_weight_size[1] * groups) 5843 weight_size.append(prepacked_weight_size[0] / groups) 5844 for d in range(2, dim): 5845 weight_size.append(prepacked_weight_size[d]) 5846 else: 5847 weight_size = prepacked_weight.transpose(0, 1).size() 5848 return weight_size 5849 5850 x.realize() 5851 weight.realize() 5852 if bias is not None: 5853 bias.realize() 5854 with V.graph.fake_mode: 5855 # TODO <Leslie> cleaned up the fake_tensor trace as Linear implementation 5856 x_fake = ir_node_to_tensor(x, guard_shape=True) 5857 weight_fake = ir_node_to_tensor(weight, guard_shape=True) 5858 dims = len(x_fake.size()) - 2 5859 assert 0 < len(padding) <= dims 5860 assert 0 < len(dilation) <= dims 5861 assert 0 < len(stride) <= dims 5862 padding = pad_listlike(padding, dims) 5863 dilation = pad_listlike(dilation, dims) 5864 stride = pad_listlike(stride, dims) 5865 if output_padding is None: 5866 output_padding = pad_listlike([0], dims) 5867 else: 5868 assert 0 < len(output_padding) <= dims 5869 output_padding = pad_listlike(output_padding, dims) 5870 assert isinstance(groups, int) 5871 if transposed: 5872 # When transposed, the size of the prepacked oneDNN weight is different 5873 # from the PyTorch weight. We're not able to run aten conv with such 5874 # size. 
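            # (Illustrative numbers only, as a sanity check of the ported formula:
            # for one spatial dim with x length 5, kernel 3, stride 2, padding 1,
            # dilation 1 and output_padding 0, _conv_input_size computes
            # kernel = (3 - 1) * 1 + 1 = 3 and (5 - 1) * 2 - 2 * 1 + 3 + 0 = 9,
            # the expected transposed-conv output length.)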
We infer the output size from the input params here: 5875 weight_size = _original_deconv_weight_size(weight_fake, groups) 5876 input_size = x_fake.size() 5877 output_size = _conv_input_size( 5878 input_size, 5879 weight_size, 5880 padding, 5881 output_padding, 5882 stride, 5883 dilation, 5884 groups, 5885 ) 5886 else: 5887 bias_fake = ( 5888 ir_node_to_tensor(bias, guard_shape=True) if bias is not None else bias 5889 ) 5890 output = torch.ops.aten.convolution( 5891 x_fake, 5892 weight_fake, 5893 bias_fake, 5894 stride, 5895 padding, 5896 dilation, 5897 transposed, 5898 output_padding, 5899 groups, 5900 ) 5901 output_size = output.size() 5902 5903 req_stride_order = [0] + list(reversed(range(1, len(stride) + 1))) 5904 req_stride_order = [len(req_stride_order)] + req_stride_order 5905 5906 x = cls.require_stride_order(x, req_stride_order) 5907 5908 # We won't do weight prepack for Conv if dynamic_shapes. 5909 # In static shape cases, since weight is prepacked, we'll always force output to be channels last in the Conv kernel. 5910 # In dynamic shape cases, for input with channels = 1, like tensor of size (s0, 1, 28, 28) and stride (784, 784, 28, 1), 5911 # x = cls.require_stride_order(x, req_stride_order) where req_stride_order is in the channels last order 5912 # won't change the stride of this tensor since stride for dimensions of size 1 is ignored. While in Conv kernel, 5913 # this tensor is considered as channels first and the output will be in contiguous format. 5914 # To align the behavior of the Conv kernel, we set the output_stride in such case to be contiguous instead of channels last. 5915 dynamic_shapes = not all(isinstance(i, int) for i in (output_size)) 5916 if dynamic_shapes and is_contiguous_storage_and_layout(x): 5917 output_stride = FlexibleLayout.contiguous_strides(output_size) 5918 else: 5919 output_stride = make_channels_last_strides_for(output_size) 5920 5921 assert x.get_device().type == "cpu" and weight.get_device().type == "cpu" 5922 inputs = [x, weight] 5923 5924 kernel_layout = FixedLayout( 5925 x.get_device(), 5926 x.get_dtype(), 5927 convert_shape_to_inductor(output_size), 5928 convert_shape_to_inductor(output_stride), 5929 ) 5930 constant_args = [padding, stride, dilation, groups] 5931 if transposed: 5932 constant_args.insert(1, output_padding) 5933 5934 if bias is not None: 5935 inputs.append(bias) 5936 else: 5937 constant_args.insert(0, bias) 5938 return inputs, constant_args, kernel_layout, req_stride_order 5939 5940 5941def _prepare_linear_fusion_create( 5942 cls, 5943 x: "TensorBox", 5944 weight: "TensorBox", 5945 bias: "TensorBox", 5946): 5947 """ 5948 This function is a helper function to prepare inputs, layout and constant args 5949 for linear post-op fusion's create function. The function only supports the CPU device 5950 since linear post-op fusion kernel is only supported on CPU right now. 5951 """ 5952 x.realize() 5953 weight.realize() 5954 if bias is not None: 5955 bias.realize() 5956 5957 *m, _ = x.get_size() 5958 # The weight has been transposed during the qlinear weight prepack process. 
5959 # https://github.com/pytorch/pytorch/blob/4979f9c0d72490970e2019bb1d2284f83d93f76b/ 5960 # aten/src/ATen/native/quantized/cpu/qlinear_prepack.cpp#L291 5961 _, oc = weight.get_size() 5962 output_size = list(m) + [oc] 5963 req_stride_order = list(reversed(range(len(x.get_size())))) 5964 5965 x = cls.require_stride_order(x, req_stride_order) 5966 assert x.get_device().type == "cpu" and weight.get_device().type == "cpu" 5967 inputs = [x, weight] 5968 5969 output_stride = FlexibleLayout.contiguous_strides(output_size) 5970 kernel_layout = FixedLayout( 5971 x.get_device(), 5972 x.get_dtype(), 5973 output_size, 5974 output_stride, 5975 ) 5976 constant_args: List[Any] = [] 5977 5978 if bias is not None: 5979 inputs.append(bias) 5980 else: 5981 constant_args.insert(0, bias) 5982 return inputs, constant_args, kernel_layout, req_stride_order 5983 5984 5985class ConvolutionUnary(ExternKernelAlloc): 5986 def __init__( 5987 self, 5988 layout, 5989 inputs, 5990 constant_args=(), 5991 ): 5992 super().__init__( 5993 layout, 5994 inputs, 5995 constant_args, 5996 None, 5997 python_kernel_name="torch.ops.mkldnn._convolution_pointwise", 5998 cpp_kernel_name="mkldnn::_convolution_pointwise", 5999 ) 6000 self.cpp_kernel_key = "convolution_pointwise" 6001 self.cpp_op_schema = """ 6002 at::Tensor( 6003 const at::Tensor& input_t, 6004 const at::Tensor& weight_t, 6005 const c10::optional<at::Tensor>& bias_opt, 6006 at::IntArrayRef padding, 6007 at::IntArrayRef stride, 6008 at::IntArrayRef dilation, 6009 int64_t groups, 6010 c10::string_view attr, 6011 torch::List<c10::optional<at::Scalar>> scalars, 6012 c10::optional<c10::string_view> algorithm)""" 6013 6014 def codegen(self, wrapper): 6015 wrapper.generate_extern_kernel_alloc_and_find_schema_if_needed( 6016 self.get_name(), 6017 self.python_kernel_name, 6018 self.cpp_kernel_name, 6019 self.codegen_args(), 6020 self.cpp_op_schema, 6021 self.cpp_kernel_key, 6022 ) 6023 if isinstance(self.layout, Layout): 6024 self.codegen_size_asserts(wrapper) 6025 6026 @classmethod 6027 def create( 6028 cls, 6029 x: "TensorBox", 6030 weight: "TensorBox", 6031 bias: "TensorBox", 6032 padding_: List[int], 6033 stride_: List[int], 6034 dilation_: List[int], 6035 groups: int, 6036 attr, 6037 scalars: Optional[List[Any]], 6038 algorithm, 6039 ): 6040 (inputs, constant_args, kernel_layout, _) = _prepare_convolution_fusion_create( 6041 cls, x, weight, bias, padding_, stride_, dilation_, groups 6042 ) 6043 constant_args = constant_args + [ 6044 attr, 6045 may_convert_to_optional(scalars), 6046 algorithm, 6047 ] 6048 return ConvolutionUnary( 6049 layout=kernel_layout, 6050 inputs=inputs, 6051 constant_args=constant_args, 6052 ) 6053 6054 6055class ConvolutionBinary(ExternKernelAlloc): 6056 def __init__( 6057 self, 6058 layout, 6059 inputs, 6060 constant_args=(), 6061 cpp_constant_args=(), 6062 ): 6063 super().__init__( 6064 layout, 6065 inputs, 6066 constant_args, 6067 None, 6068 python_kernel_name="torch.ops.mkldnn._convolution_pointwise.binary", 6069 cpp_kernel_name="mkldnn::_convolution_pointwise", 6070 ) 6071 self.cpp_kernel_overload_name = "binary" 6072 self.cpp_kernel_key = "convolution_pointwise_binary" 6073 self.cpp_op_schema = """ 6074 at::Tensor( 6075 const at::Tensor& input_t, 6076 const at::Tensor& other_t, 6077 const at::Tensor& weight_t, 6078 const c10::optional<at::Tensor>& bias_opt, 6079 at::IntArrayRef padding, 6080 at::IntArrayRef stride, 6081 at::IntArrayRef dilation, 6082 int64_t groups, 6083 c10::string_view binary_attr, 6084 c10::optional<at::Scalar> alpha, 6085 
c10::optional<c10::string_view> unary_attr, 6086 torch::List<c10::optional<at::Scalar>> unary_scalars, 6087 c10::optional<c10::string_view> unary_algorithm)""" 6088 self.cpp_constant_args = cpp_constant_args 6089 6090 def codegen(self, wrapper): 6091 wrapper.generate_extern_kernel_alloc_and_find_schema_if_needed( 6092 self.get_name(), 6093 self.python_kernel_name, 6094 self.cpp_kernel_name, 6095 self.codegen_args(), 6096 self.cpp_op_schema, 6097 self.cpp_kernel_key, 6098 self.cpp_kernel_overload_name, 6099 ) 6100 if isinstance(self.layout, Layout): 6101 self.codegen_size_asserts(wrapper) 6102 6103 @classmethod 6104 def create( 6105 cls, 6106 x: "TensorBox", 6107 other: "TensorBox", 6108 weight: "TensorBox", 6109 bias: "TensorBox", 6110 padding_: List[int], 6111 stride_: List[int], 6112 dilation_: List[int], 6113 groups: int, 6114 binary_attr: str, 6115 binary_alpha: Optional[float], 6116 unary_attr: Optional[str], 6117 unary_scalars: Optional[List[Any]], 6118 unary_algorithm: Optional[str], 6119 ): 6120 ( 6121 inputs, 6122 constant_args, 6123 kernel_layout, 6124 req_stride_order, 6125 ) = _prepare_convolution_fusion_create( 6126 cls, x, weight, bias, padding_, stride_, dilation_, groups 6127 ) 6128 other = cls.require_stride_order(other, req_stride_order) 6129 inputs.insert(1, other) 6130 constant_args = constant_args + [ 6131 binary_attr, 6132 binary_alpha, 6133 unary_attr, 6134 may_convert_to_optional(unary_scalars), 6135 unary_algorithm, 6136 ] 6137 return ConvolutionBinary( 6138 layout=kernel_layout, 6139 inputs=inputs, 6140 constant_args=constant_args, 6141 ) 6142 6143 6144class ConvolutionBinaryInplace(ExternKernelAlloc): 6145 def __init__( 6146 self, 6147 kernel_layout, 6148 inputs, 6149 constant_args=(), 6150 ): 6151 # Due to constrain of op.call, other (Tensor&) should be at input[0] 6152 reordered_inputs = [inputs[1], inputs[0]] + inputs[2:] 6153 6154 super().__init__( 6155 kernel_layout, 6156 reordered_inputs, 6157 constant_args, 6158 None, 6159 python_kernel_name="torch.ops.mkldnn._convolution_pointwise_.binary", 6160 cpp_kernel_name="mkldnn::_convolution_pointwise_", 6161 ) 6162 self.cpp_kernel_overload_name = "binary" 6163 self.cpp_kernel_key = "convolution_pointwise_binary_" 6164 # TODO: op.call: input[0] should be at::Tensor& 6165 self.cpp_op_schema = """ 6166 at::Tensor&( 6167 at::Tensor& other_t, 6168 const at::Tensor& input_t, 6169 const at::Tensor& weight_t, 6170 const c10::optional<at::Tensor>& bias_opt, 6171 at::IntArrayRef padding, 6172 at::IntArrayRef stride, 6173 at::IntArrayRef dilation, 6174 int64_t groups, 6175 c10::string_view binary_attr, 6176 c10::optional<at::Scalar> alpha, 6177 c10::optional<c10::string_view> unary_attr, 6178 torch::List<c10::optional<at::Scalar>> unary_scalars, 6179 c10::optional<c10::string_view> unary_algorithm)""" 6180 6181 def codegen(self, wrapper): 6182 wrapper.generate_extern_kernel_alloc_and_find_schema_if_needed( 6183 self.get_name(), 6184 self.python_kernel_name, 6185 self.cpp_kernel_name, 6186 self.codegen_args(), 6187 self.cpp_op_schema, 6188 self.cpp_kernel_key, 6189 self.cpp_kernel_overload_name, 6190 ) 6191 6192 def get_mutation_names(self): 6193 return [self.inputs[0].get_name()] 6194 6195 def get_unbacked_symbol_defs(self) -> Set[sympy.Symbol]: 6196 return set() 6197 6198 @classmethod 6199 def create( 6200 cls, 6201 x: "TensorBox", 6202 other: "TensorBox", 6203 weight: "TensorBox", 6204 bias: "TensorBox", 6205 padding_: List[int], 6206 stride_: List[int], 6207 dilation_: List[int], 6208 groups: int, 6209 binary_attr: str, 
6210 binary_alpha: Optional[float], 6211 unary_attr: Optional[str], 6212 unary_scalars: Optional[List[Any]], 6213 unary_algorithm: Optional[str], 6214 ): 6215 ( 6216 inputs, 6217 constant_args, 6218 _, 6219 req_stride_order, 6220 ) = _prepare_convolution_fusion_create( 6221 cls, x, weight, bias, padding_, stride_, dilation_, groups 6222 ) 6223 other = cls.require_stride_order(other, req_stride_order) 6224 inputs.insert(1, other) 6225 constant_args = constant_args + [ 6226 binary_attr, 6227 binary_alpha, 6228 unary_attr, 6229 may_convert_to_optional(unary_scalars), 6230 unary_algorithm, 6231 ] 6232 packed = ConvolutionBinaryInplace( 6233 kernel_layout=NoneLayout(inputs[1].get_device()), # type: ignore[arg-type] 6234 inputs=inputs, 6235 constant_args=constant_args, 6236 ) 6237 mark_node_as_mutating(packed, inputs[1]) 6238 # This op mutates in place which means that the result is not the 6239 # target but rather the input that is being mutated 6240 # init reorders the inputs, so inputs[1] becomes packed.inputs[0] 6241 return packed.inputs[0] 6242 6243 6244class MKLPackedLinear(ExternKernelAlloc): 6245 def __init__( 6246 self, 6247 layout, 6248 inputs, 6249 constant_args=(), 6250 ): 6251 super().__init__( 6252 layout, 6253 inputs, 6254 constant_args, 6255 None, 6256 python_kernel_name="torch.ops.mkl._mkl_linear", 6257 cpp_kernel_name="mkl::_mkl_linear", 6258 ) 6259 self.cpp_kernel_key = "mkl_linear" 6260 self.cpp_op_schema = """ 6261 at::Tensor( 6262 const at::Tensor& self, 6263 const at::Tensor& mkl_weight_t, 6264 const at::Tensor& origin_weight_t, 6265 const c10::optional<at::Tensor>& bias_opt, 6266 const int64_t prepack_batch_size)""" 6267 6268 def codegen(self, wrapper): 6269 wrapper.generate_extern_kernel_alloc_and_find_schema_if_needed( 6270 self.get_name(), 6271 self.python_kernel_name, 6272 self.cpp_kernel_name, 6273 self.codegen_args(), 6274 self.cpp_op_schema, 6275 self.cpp_kernel_key, 6276 ) 6277 6278 @classmethod 6279 def create(cls, x, packed_w, orig_w, B, batch_size): 6280 x = cls.require_stride1(cls.realize_input(x)) 6281 orig_w = cls.require_stride1(cls.realize_input(orig_w)) 6282 *m, _ = x.get_size() 6283 oc, _ = orig_w.get_size() 6284 output_size = list(m) + [oc] 6285 output_stride = FlexibleLayout.contiguous_strides(output_size) 6286 inputs = [x, packed_w, orig_w] 6287 constant_args = [batch_size] 6288 if B is not None: 6289 inputs += [B] 6290 else: 6291 constant_args.insert(0, None) 6292 6293 return MKLPackedLinear( 6294 layout=FixedLayout( 6295 x.get_device(), x.get_dtype(), output_size, output_stride 6296 ), 6297 inputs=inputs, 6298 constant_args=constant_args, 6299 ) 6300 6301 6302class LinearUnary(ExternKernelAlloc): 6303 def __init__( 6304 self, 6305 layout, 6306 inputs, 6307 constant_args=(), 6308 ): 6309 super().__init__( 6310 layout, 6311 inputs, 6312 constant_args, 6313 None, 6314 python_kernel_name="torch.ops.mkldnn._linear_pointwise", 6315 cpp_kernel_name="mkldnn::_linear_pointwise", 6316 ) 6317 self.cpp_kernel_key = "linear_pointwise" 6318 self.cpp_op_schema = """ 6319 at::Tensor( 6320 const at::Tensor& input_t, 6321 const at::Tensor& weight_t, 6322 const c10::optional<at::Tensor>& bias_opt, 6323 c10::string_view attr, 6324 torch::List<c10::optional<at::Scalar>> scalars, 6325 c10::optional<c10::string_view> algorithm)""" 6326 6327 def codegen(self, wrapper): 6328 wrapper.generate_extern_kernel_alloc_and_find_schema_if_needed( 6329 self.get_name(), 6330 self.python_kernel_name, 6331 self.cpp_kernel_name, 6332 self.codegen_args(), 6333 self.cpp_op_schema, 6334 
self.cpp_kernel_key, 6335 ) 6336 6337 @classmethod 6338 def create(cls, x, w, b, attr, scalars, algorithm): 6339 x = cls.require_contiguous(cls.realize_input(x)) 6340 w = cls.require_contiguous(cls.realize_input(w)) 6341 6342 *m, ic = x.get_size() 6343 oc, ic = w.get_size() 6344 inputs = [x, w] 6345 constant_args = [attr, scalars if scalars else [-1], algorithm] 6346 if b is not None: 6347 b = cls.require_contiguous(cls.realize_input(b)) 6348 inputs.append(b) 6349 else: 6350 constant_args.insert(0, None) 6351 6352 return LinearUnary( 6353 layout=FlexibleLayout( 6354 device=x.get_device(), 6355 dtype=x.get_dtype(), 6356 size=list(m) + [oc], 6357 ), 6358 inputs=inputs, 6359 constant_args=constant_args, 6360 ) 6361 6362 def apply_constraint(self): 6363 pass 6364 6365 6366class LinearBinary(ExternKernelAlloc): 6367 kernel = "torch.ops.mkldnn._linear_pointwise.binary" 6368 6369 def __init__( 6370 self, 6371 layout, 6372 inputs, 6373 constant_args=(), 6374 ): 6375 super().__init__( 6376 layout, 6377 inputs, 6378 constant_args, 6379 None, 6380 python_kernel_name="torch.ops.mkldnn._linear_pointwise.binary", 6381 cpp_kernel_name="mkldnn::_linear_pointwise", 6382 ) 6383 self.cpp_kernel_overload_name = "binary" 6384 self.cpp_kernel_key = "linear_pointwise_binary" 6385 self.cpp_op_schema = """ 6386 at::Tensor( 6387 const at::Tensor& input_t, 6388 const at::Tensor& other_t, 6389 const at::Tensor& weight_t, 6390 const c10::optional<at::Tensor>& bias_opt, 6391 c10::string_view attr) 6392 """ 6393 6394 def codegen(self, wrapper): 6395 wrapper.generate_extern_kernel_alloc_and_find_schema_if_needed( 6396 self.get_name(), 6397 self.python_kernel_name, 6398 self.cpp_kernel_name, 6399 self.codegen_args(), 6400 self.cpp_op_schema, 6401 self.cpp_kernel_key, 6402 self.cpp_kernel_overload_name, 6403 ) 6404 6405 @classmethod 6406 def create(cls, x, y, w, B, attr): 6407 x = cls.require_contiguous(cls.realize_input(x)) 6408 y = cls.require_contiguous(cls.realize_input(y)) 6409 w = cls.require_contiguous(cls.realize_input(w)) 6410 6411 *m, ic = x.get_size() 6412 oc, ic = w.get_size() 6413 6414 inputs = [x, y, w] 6415 constant_args = [attr] 6416 if B is not None: 6417 B = cls.require_contiguous(cls.realize_input(B)) 6418 inputs.append(B) 6419 else: 6420 constant_args.insert(0, B) 6421 6422 return LinearBinary( 6423 layout=FlexibleLayout( 6424 device=x.get_device(), 6425 dtype=x.get_dtype(), 6426 size=list(m) + [oc], 6427 ), 6428 inputs=inputs, 6429 constant_args=constant_args, 6430 ) 6431 6432 def apply_constraint(self): 6433 pass 6434 6435 6436class ConvolutionTransposeUnary(ExternKernelAlloc): 6437 def __init__( 6438 self, 6439 layout, 6440 inputs, 6441 constant_args=(), 6442 ): 6443 super().__init__( 6444 layout, 6445 inputs, 6446 constant_args, 6447 None, 6448 python_kernel_name="torch.ops.mkldnn._convolution_transpose_pointwise", 6449 cpp_kernel_name="mkldnn::_convolution_transpose_pointwise", 6450 ) 6451 self.cpp_kernel_key = "convolution_transpose_pointwise" 6452 self.cpp_op_schema = """ 6453 at::Tensor( 6454 const at::Tensor& input_t, 6455 const at::Tensor& weight_t, 6456 const c10::optional<at::Tensor>& bias_opt, 6457 at::IntArrayRef padding, 6458 at::IntArrayRef output_padding, 6459 at::IntArrayRef stride, 6460 at::IntArrayRef dilation, 6461 int64_t groups, 6462 c10::string_view attr, 6463 torch::List<c10::optional<at::Scalar>> scalars, 6464 c10::optional<c10::string_view> algorithm)""" 6465 6466 def codegen(self, wrapper): 6467 wrapper.generate_extern_kernel_alloc_and_find_schema_if_needed( 6468 
self.get_name(), 6469 self.python_kernel_name, 6470 self.cpp_kernel_name, 6471 self.codegen_args(), 6472 self.cpp_op_schema, 6473 self.cpp_kernel_key, 6474 ) 6475 6476 @classmethod 6477 def create( 6478 cls, 6479 x: "TensorBox", 6480 weight: "TensorBox", 6481 bias: "TensorBox", 6482 padding_: List[int], 6483 output_padding_: List[int], 6484 stride_: List[int], 6485 dilation_: List[int], 6486 groups_: int, 6487 attr, 6488 scalars: Optional[List[Any]], 6489 algorithm, 6490 ): 6491 transposed = True 6492 ( 6493 inputs, 6494 constant_args, 6495 kernel_layout, 6496 _, 6497 ) = _prepare_convolution_fusion_create( 6498 cls, 6499 x, 6500 weight, 6501 bias, 6502 padding_, 6503 stride_, 6504 dilation_, 6505 groups_, 6506 transposed, 6507 output_padding_, 6508 ) 6509 constant_args = constant_args + [ 6510 attr, 6511 may_convert_to_optional(scalars), 6512 algorithm, 6513 ] 6514 return ConvolutionTransposeUnary( 6515 layout=kernel_layout, 6516 inputs=inputs, 6517 constant_args=constant_args, 6518 ) 6519 6520 6521class MkldnnRnnLayer(ExternKernelAlloc): 6522 def __init__( 6523 self, 6524 layout, 6525 inputs, 6526 constant_args=(), 6527 ): 6528 super().__init__( 6529 layout, 6530 inputs, 6531 constant_args, 6532 None, 6533 python_kernel_name="aten.mkldnn_rnn_layer", 6534 cpp_kernel_name="at::mkldnn_rnn_layer", 6535 ) 6536 6537 @classmethod 6538 def create( 6539 cls, 6540 x: "TensorBox", 6541 w0: "TensorBox", 6542 w1: "TensorBox", 6543 w2: "TensorBox", 6544 w3: "TensorBox", 6545 hx: "TensorBox", 6546 cx: "TensorBox", 6547 reverse: bool, 6548 batch_sizes: List[int], 6549 mode: int, 6550 hidden_size: int, 6551 num_layers: int, 6552 has_biases: bool, 6553 bidirectional: bool, 6554 batch_first: bool, 6555 train: bool, 6556 ): 6557 x = cls.require_stride1(cls.realize_input(x)) 6558 # If batch_first, x has been permuted in lstm before entering the mkldnn_rnn_layer. 6559 # Make sure x is contiguous in batch_first case. 6560 x.freeze_layout() 6561 w0 = cls.require_stride1(cls.realize_input(w0)) 6562 w1 = cls.require_stride1(cls.realize_input(w1)) 6563 w2 = cls.require_stride1(cls.realize_input(w2)) 6564 w3 = cls.require_stride1(cls.realize_input(w3)) 6565 hx = cls.require_stride1(cls.realize_input(hx)) 6566 hx.freeze_layout() 6567 cx = cls.require_stride1(cls.realize_input(cx)) 6568 cx.freeze_layout() 6569 6570 input_size = x.get_size() 6571 assert len(input_size) == 3, "Expect lstm input to be 3D" 6572 # batch_first is handled in the lstm OP. 
When entering 6573 # rnn_layer here, we'll always have batch_first = False 6574 seq_length, mini_batch, input_size = input_size 6575 output_shape = [seq_length, mini_batch, hidden_size] 6576 6577 hy_shape = hx.get_size() 6578 cy_shape = cx.get_size() 6579 6580 res: List[IRNode] = [] 6581 6582 inputs = [x, w0, w1, w2, w3, hx, cx] 6583 constant_args = [ 6584 reverse, 6585 batch_sizes, 6586 mode, 6587 hidden_size, 6588 num_layers, 6589 has_biases, 6590 bidirectional, 6591 batch_first, 6592 train, 6593 ] 6594 6595 packed = MkldnnRnnLayer( 6596 MultiOutputLayout(x.get_device()), 6597 inputs=inputs, 6598 constant_args=constant_args, 6599 ) 6600 6601 def get_strides_of_lstm_output(output_shape, batch_first): 6602 assert len(output_shape) == 3, "Expect output_shape to be 3D" 6603 return FlexibleLayout.contiguous_strides(output_shape) 6604 6605 output_sizes = [output_shape, hy_shape, cy_shape] 6606 output_strides = [ 6607 get_strides_of_lstm_output(output_shape, batch_first), 6608 FlexibleLayout.contiguous_strides(hy_shape), 6609 FlexibleLayout.contiguous_strides(cy_shape), 6610 ] 6611 output_ir = [ 6612 MultiOutput( 6613 FixedLayout( 6614 x.get_device(), 6615 x.get_dtype(), 6616 output_size, 6617 output_stride, 6618 ), 6619 packed, 6620 [(tuple, i)], 6621 ) 6622 for i, (output_size, output_stride) in enumerate( 6623 zip(output_sizes, output_strides) 6624 ) 6625 ] 6626 6627 return output_ir 6628 6629 6630class QConvPointWisePT2E(ExternKernelAlloc): 6631 def __init__( 6632 self, 6633 layout, 6634 inputs, 6635 constant_args=(), 6636 ): 6637 """ 6638 if bias is not None 6639 - inputs = [x, w, b, weight_scale, weight_zp] 6640 - const_args is: [stride, padding, dilation, groups, x_scale, x_zp, o_inv_scale, o_zp, 6641 fp32_output, unary_attr, unary_scalars, unary_algorithm] 6642 else 6643 - inputs = [x, w, weight_scale, weight_zp] 6644 - const_args is: [bias, stride, padding, dilation, groups, x_scale, x_zp, o_inv_scale, o_zp, 6645 fp32_output, unary_attr, unary_scalars, unary_algorithm] 6646 """ 6647 self.has_bias = len(inputs) == 5 6648 super().__init__( 6649 layout, 6650 inputs, 6651 constant_args, 6652 None, 6653 python_kernel_name="torch.ops.onednn.qconv2d_pointwise", 6654 cpp_kernel_name="onednn::qconv2d_pointwise", 6655 ) 6656 self.cpp_kernel_key = "qconv2d_pointwise" 6657 self.cpp_op_schema = """ 6658 at::Tensor( 6659 at::Tensor act, 6660 double act_scale, 6661 int64_t act_zero_point, 6662 at::Tensor weight, 6663 at::Tensor weight_scales, 6664 at::Tensor weight_zero_points, 6665 c10::optional<at::Tensor> bias, 6666 torch::List<int64_t> stride, 6667 torch::List<int64_t> padding, 6668 torch::List<int64_t> dilation, 6669 int64_t groups, 6670 double output_scale, 6671 int64_t output_zero_point, 6672 c10::optional<c10::ScalarType> output_dtype, 6673 c10::string_view attr, 6674 torch::List<c10::optional<at::Scalar>> scalars, 6675 c10::optional<c10::string_view> algorithm)""" 6676 6677 def codegen(self, wrapper): 6678 # Parser the inputs and constant 6679 args = [x.codegen_reference() for x in self.inputs] 6680 const_args = [] 6681 const_args.extend(self.codegen_const_args()) 6682 6683 x = args[0] 6684 packed_weight = args[1] 6685 bias = args[2] if self.has_bias else const_args[0] 6686 w_scale, w_zp = args[-2], args[-1] 6687 ( 6688 stride, 6689 padding, 6690 dilation, 6691 groups, 6692 x_scale, 6693 x_zp, 6694 o_inv_scale, 6695 o_zp, 6696 output_dtype, 6697 unary_attr, 6698 unary_scalars, 6699 unary_algorithm, 6700 ) = const_args[-12:] 6701 6702 codegen_args = ( 6703 x, 6704 x_scale, 6705 x_zp, 6706 
packed_weight, 6707 w_scale, 6708 w_zp, 6709 bias, 6710 stride, 6711 padding, 6712 dilation, 6713 groups, 6714 o_inv_scale, 6715 o_zp, 6716 output_dtype, 6717 unary_attr, 6718 unary_scalars, 6719 unary_algorithm, 6720 ) 6721 wrapper.generate_extern_kernel_alloc_and_find_schema_if_needed( 6722 self.get_name(), 6723 self.python_kernel_name, 6724 self.cpp_kernel_name, 6725 codegen_args, 6726 self.cpp_op_schema, 6727 self.cpp_kernel_key, 6728 ) 6729 if isinstance(self.layout, Layout): 6730 self.codegen_size_asserts(wrapper) 6731 6732 @classmethod 6733 def create( 6734 cls, 6735 x: "TensorBox", 6736 x_scale: float, 6737 x_zp: int, 6738 weight: "TensorBox", # packed_weight 6739 w_scale: "TensorBox", 6740 w_zp: "TensorBox", 6741 bias: "TensorBox", 6742 stride_: List[int], 6743 padding_: List[int], 6744 dilation_: List[int], 6745 groups: int, 6746 o_inv_scale: float, 6747 output_zero_point: int, 6748 output_dtype, 6749 unary_attr, 6750 unary_scalars, 6751 unary_algorithm, 6752 ): 6753 transposed = False 6754 output_padding = None 6755 (inputs, constant_args, kernel_layout, _) = _prepare_convolution_fusion_create( 6756 cls, 6757 x, 6758 weight, 6759 bias, 6760 padding_, 6761 stride_, 6762 dilation_, 6763 groups, 6764 transposed, 6765 output_padding, 6766 ) 6767 # swap padding and stride to align with functional conv arg order 6768 if bias is None: 6769 constant_args[1], constant_args[2] = constant_args[2], constant_args[1] 6770 else: 6771 constant_args[0], constant_args[1] = constant_args[1], constant_args[0] 6772 6773 w_scale.realize() 6774 w_zp.realize() 6775 inputs = inputs + [w_scale, w_zp] 6776 constant_args = constant_args + [ 6777 x_scale, 6778 x_zp, 6779 o_inv_scale, 6780 output_zero_point, 6781 output_dtype, 6782 unary_attr, 6783 may_convert_to_optional(unary_scalars), 6784 unary_algorithm, 6785 ] 6786 6787 if output_dtype is not None: 6788 assert output_dtype in [torch.float32, torch.bfloat16] 6789 # in _prepare_convolution_fusion_create, we use x.dtype (uint8) to create kernel_layout 6790 # if we set output_dtype is not None, the output buf should be output_dtype instead of uint8. 
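            # (e.g. a quantized conv whose epilogue dequantizes to float32 or
            # bfloat16 allocates a float32/bfloat16 output buffer here instead of
            # a uint8 one.)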
6791 kernel_layout.dtype = output_dtype 6792 6793 return QConvPointWisePT2E( 6794 layout=kernel_layout, 6795 inputs=inputs, 6796 constant_args=constant_args, 6797 ) 6798 6799 6800class QConvPointWiseBinaryPT2E(ExternKernelAlloc): 6801 def __init__( 6802 self, 6803 layout, 6804 inputs, 6805 constant_args=(), 6806 ): 6807 """ 6808 Needs input/weight/output qparams 6809 if bias is not None 6810 - inputs = [x, w, b, accum, w_scale, w_zp] 6811 - const_args = [stride, padding, dilation, groups, x_scale, x_zp, accum_scale, accum_zp, o_inv_scale, o_zp, 6812 fp32_output, binary_attr, aplha, unary_attr, unary_scalars, unary_algorithm] 6813 else 6814 - inputs = [x, w, accum, w_scale, w_zp] 6815 - const_args = const_args is: [bias, stride, padding, dilation, groups, x_scale, x_zp, accum_scale, 6816 accum_zp, o_inv_scale, o_zp, fp32_output, binary_attr, aplha, unary_attr, unary_scalars, unary_algorithm] 6817 """ 6818 self.has_bias = len(inputs) == 6 6819 self.idx_for_inplace_sum = 3 if self.has_bias else 2 6820 super().__init__( 6821 layout, 6822 inputs, 6823 constant_args, 6824 None, 6825 python_kernel_name="torch.ops.onednn.qconv2d_pointwise.binary", 6826 cpp_kernel_name="onednn::qconv2d_pointwise", 6827 ) 6828 self.cpp_kernel_overload_name = "binary" 6829 self.cpp_kernel_key = "qconv2d_pointwise_binary" 6830 self.cpp_op_schema = """ 6831 at::Tensor( 6832 at::Tensor act, 6833 double act_scale, 6834 int64_t act_zero_point, 6835 at::Tensor accum, 6836 double accum_scale, 6837 int64_t accum_zero_point, 6838 at::Tensor weight, 6839 at::Tensor weight_scales, 6840 at::Tensor weight_zero_points, 6841 c10::optional<at::Tensor> bias, 6842 torch::List<int64_t> stride, 6843 torch::List<int64_t> padding, 6844 torch::List<int64_t> dilation, 6845 int64_t groups, 6846 double output_scale, 6847 int64_t output_zero_point, 6848 c10::optional<c10::ScalarType> output_dtype, 6849 c10::string_view binary_attr, 6850 c10::optional<at::Scalar> alpha, 6851 c10::optional<c10::string_view> attr, 6852 torch::List<c10::optional<at::Scalar>> scalars, 6853 c10::optional<c10::string_view> algorithm)""" 6854 6855 def codegen(self, wrapper): 6856 # Parser the inputs and constant 6857 args = [x.codegen_reference() for x in self.inputs] 6858 const_args = [] 6859 const_args.extend(self.codegen_const_args()) 6860 6861 x = args[0] 6862 packed_weight = args[1] 6863 bias = args[2] if self.has_bias else const_args[0] 6864 accum, w_scale, w_zp = args[-3], args[-2], args[-1] 6865 ( 6866 stride, 6867 padding, 6868 dilation, 6869 groups, 6870 x_scale, 6871 x_zp, 6872 accum_scale, 6873 accum_zp, 6874 o_inv_scale, 6875 o_zp, 6876 output_dtype, 6877 binary_attr, 6878 alpha, 6879 unary_attr, 6880 unary_scalars, 6881 unary_algorithm, 6882 ) = const_args[-16:] 6883 conv_args = ( 6884 x, 6885 x_scale, 6886 x_zp, 6887 accum, 6888 accum_scale, 6889 accum_zp, 6890 packed_weight, 6891 w_scale, 6892 w_zp, 6893 bias, 6894 stride, 6895 padding, 6896 dilation, 6897 groups, 6898 o_inv_scale, 6899 o_zp, 6900 output_dtype, 6901 binary_attr, 6902 alpha, 6903 unary_attr, 6904 unary_scalars, 6905 unary_algorithm, 6906 ) 6907 wrapper.generate_extern_kernel_alloc_and_find_schema_if_needed( 6908 self.get_name(), 6909 self.python_kernel_name, 6910 self.cpp_kernel_name, 6911 conv_args, 6912 self.cpp_op_schema, 6913 self.cpp_kernel_key, 6914 self.cpp_kernel_overload_name, 6915 ) 6916 if isinstance(self.layout, Layout): 6917 self.codegen_size_asserts(wrapper) 6918 6919 def get_mutation_names(self): 6920 return [self.inputs[self.idx_for_inplace_sum].get_name()] 6921 6922 def 
get_unbacked_symbol_defs(self) -> Set[sympy.Symbol]: 6923 return set() 6924 6925 @classmethod 6926 def create( 6927 cls, 6928 x: "TensorBox", 6929 x_scale, 6930 x_zp, 6931 accum: "TensorBox", 6932 accum_scale, 6933 accum_zp, 6934 weight: "TensorBox", # packed_weight 6935 w_scale, 6936 w_zp, 6937 bias: "TensorBox", 6938 stride_: List[int], 6939 padding_: List[int], 6940 dilation_: List[int], 6941 groups: int, 6942 o_inv_scale: "TensorBox", 6943 output_zero_point: "TensorBox", 6944 output_dtype, 6945 binary_attr, 6946 alpha, 6947 unary_attr, 6948 unary_scalars, 6949 unary_algorithm, 6950 ): 6951 transposed = False 6952 output_padding = None 6953 ( 6954 inputs, 6955 constant_args, 6956 kernel_layout, 6957 req_stride_order, 6958 ) = _prepare_convolution_fusion_create( 6959 cls, 6960 x, 6961 weight, 6962 bias, 6963 padding_, 6964 stride_, 6965 dilation_, 6966 groups, 6967 transposed, 6968 output_padding, 6969 ) 6970 6971 accum = cls.require_stride_order(accum, req_stride_order) 6972 inputs.append(accum) 6973 6974 # swap padding and stride to align with functional conv arg order 6975 if bias is None: 6976 constant_args[1], constant_args[2] = constant_args[2], constant_args[1] 6977 else: 6978 constant_args[0], constant_args[1] = constant_args[1], constant_args[0] 6979 6980 w_scale.realize() 6981 w_zp.realize() 6982 inputs = inputs + [w_scale, w_zp] 6983 constant_args = constant_args + [ 6984 x_scale, 6985 x_zp, 6986 accum_scale, 6987 accum_zp, 6988 o_inv_scale, 6989 output_zero_point, 6990 output_dtype, 6991 binary_attr, 6992 alpha, 6993 unary_attr, 6994 may_convert_to_optional(unary_scalars), 6995 unary_algorithm, 6996 ] 6997 6998 assert ( 6999 binary_attr == "sum" 7000 ), "For now, only post op sum is supported in QConvPointWiseBinaryPT2E." 7001 7002 packed = QConvPointWiseBinaryPT2E( 7003 layout=NoneLayout(accum.get_device()), 7004 inputs=inputs, 7005 constant_args=constant_args, 7006 ) 7007 mark_node_as_mutating(packed, accum) 7008 7009 # Return accum since it has been inplace changed. 
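        # With a bias the inputs are [x, w, b, accum, w_scale, w_zp], so accum
        # sits at index 3; without a bias it sits at index 2, which is what
        # idx_for_inplace_sum records in __init__.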
7010 return packed.inputs[packed.idx_for_inplace_sum] 7011 7012 7013class QLinearPointwisePT2E(ExternKernelAlloc): 7014 def __init__( 7015 self, 7016 layout, 7017 inputs, 7018 constant_args=(), 7019 has_bias=True, 7020 x_scale_zp_are_tensors=False, 7021 ): 7022 """ 7023 if bias is not None 7024 - inputs = [x, w, b, weight_scale, weight_zp] 7025 - const_args is: [x_scale, x_zp, o_inv_scale, o_zp, 7026 fp32_output, unary_attr, unary_scalars, unary_algorithm] 7027 else 7028 - inputs = [x, w, weight_scale, weight_zp] 7029 - const_args is: [bias, x_scale, x_zp, o_inv_scale, o_zp, 7030 fp32_output, unary_attr, unary_scalars, unary_algorithm] 7031 """ 7032 self.has_bias = has_bias 7033 self.x_scale_zp_are_tensors = x_scale_zp_are_tensors 7034 super().__init__( 7035 layout, 7036 inputs, 7037 constant_args, 7038 None, 7039 python_kernel_name=( 7040 "torch.ops.onednn.qlinear_pointwise.tensor" 7041 if x_scale_zp_are_tensors 7042 else "torch.ops.onednn.qlinear_pointwise.default" 7043 ), 7044 cpp_kernel_name="onednn::qlinear_pointwise", 7045 ) 7046 self.cpp_kernel_overload_name = "tensor" if x_scale_zp_are_tensors else "" 7047 self.cpp_kernel_key = "qlinear_pointwise" 7048 x_scale_type_str, x_zp_type_str = ( 7049 ("at::Tensor", "at::Tensor") 7050 if x_scale_zp_are_tensors 7051 else ("double", "int64_t") 7052 ) 7053 self.cpp_op_schema = f""" 7054 at::Tensor( 7055 at::Tensor act, 7056 {x_scale_type_str} act_scale, 7057 {x_zp_type_str} act_zero_point, 7058 at::Tensor weight, 7059 at::Tensor weight_scales, 7060 at::Tensor weight_zero_points, 7061 c10::optional<at::Tensor> bias, 7062 double output_scale, 7063 int64_t output_zero_point, 7064 c10::optional<c10::ScalarType> output_dtype, 7065 c10::string_view post_op_name, 7066 torch::List<c10::optional<at::Scalar>> post_op_args, 7067 c10::string_view post_op_algorithm)""" 7068 7069 def codegen(self, wrapper): 7070 # Parser the inputs and constant 7071 args = [x.codegen_reference() for x in self.inputs] 7072 const_args = [] 7073 const_args.extend(self.codegen_const_args()) 7074 7075 x = args[0] 7076 packed_weight = args[1] 7077 bias = args[2] if self.has_bias else const_args[0] 7078 w_scale, w_zp = args[-2], args[-1] 7079 if self.x_scale_zp_are_tensors: 7080 assert len(args) >= 4 7081 x_scale, x_zp = args[-4], args[-3] 7082 ( 7083 o_inv_scale, 7084 o_zp, 7085 output_dtype, 7086 unary_attr, 7087 unary_scalars, 7088 unary_algorithm, 7089 ) = const_args[-6:] 7090 else: 7091 assert len(const_args) >= 8 7092 ( 7093 x_scale, 7094 x_zp, 7095 o_inv_scale, 7096 o_zp, 7097 output_dtype, 7098 unary_attr, 7099 unary_scalars, 7100 unary_algorithm, 7101 ) = const_args[-8:] 7102 7103 codegen_args = ( 7104 x, 7105 x_scale, 7106 x_zp, 7107 packed_weight, 7108 w_scale, 7109 w_zp, 7110 bias, 7111 o_inv_scale, 7112 o_zp, 7113 output_dtype, 7114 unary_attr, 7115 unary_scalars, 7116 unary_algorithm, 7117 ) 7118 wrapper.generate_extern_kernel_alloc_and_find_schema_if_needed( 7119 self.get_name(), 7120 self.python_kernel_name, 7121 self.cpp_kernel_name, 7122 codegen_args, 7123 self.cpp_op_schema, 7124 self.cpp_kernel_key, 7125 self.cpp_kernel_overload_name, 7126 ) 7127 if isinstance(self.layout, Layout): 7128 self.codegen_size_asserts(wrapper) 7129 7130 @classmethod 7131 def create( 7132 cls, 7133 x: "TensorBox", 7134 x_scale: float, 7135 x_zp: int, 7136 weight: "TensorBox", # packed_weight 7137 w_scale: "TensorBox", 7138 w_zp: "TensorBox", 7139 bias: "TensorBox", 7140 o_inv_scale: float, 7141 output_zero_point: int, 7142 output_dtype, 7143 unary_attr, 7144 unary_scalars, 7145 
unary_algorithm, 7146 ): 7147 (inputs, constant_args, kernel_layout, _) = _prepare_linear_fusion_create( 7148 cls, 7149 x, 7150 weight, 7151 bias, 7152 ) 7153 7154 if isinstance(x_scale, TensorBox) and isinstance(x_zp, TensorBox): 7155 x_scale.realize() 7156 x_zp.realize() 7157 inputs = inputs + [x_scale, x_zp] 7158 x_scale_zp_are_tensors = True 7159 else: 7160 assert isinstance(x_scale, float) and isinstance(x_zp, int) 7161 constant_args = constant_args + [x_scale, x_zp] 7162 x_scale_zp_are_tensors = False 7163 w_scale.realize() 7164 w_zp.realize() 7165 inputs = inputs + [w_scale, w_zp] 7166 constant_args = constant_args + [ 7167 o_inv_scale, 7168 output_zero_point, 7169 output_dtype, 7170 unary_attr, 7171 may_convert_to_optional(unary_scalars), 7172 unary_algorithm, 7173 ] 7174 7175 if output_dtype is not None: 7176 assert output_dtype in [torch.float32, torch.bfloat16] 7177 # in _prepare_linear_fusion_create, we use x.dtype (uint8) to create kernel_layout 7178 # if we set fp32_output, the output buf should be dtype float32 instead of uint8. 7179 kernel_layout.dtype = output_dtype 7180 7181 return QLinearPointwisePT2E( 7182 layout=kernel_layout, 7183 inputs=inputs, 7184 constant_args=constant_args, 7185 has_bias=(bias is not None), 7186 x_scale_zp_are_tensors=x_scale_zp_are_tensors, 7187 ) 7188 7189 7190class QLinearPointwiseBinaryPT2E(ExternKernelAlloc): 7191 def __init__( 7192 self, 7193 layout, 7194 inputs, 7195 constant_args=(), 7196 has_bias=True, 7197 x_scale_zp_are_tensors=False, 7198 ): 7199 """ 7200 if bias is not None 7201 - inputs = [x, w, b, weight_scale, weight_zp, x2] 7202 - const_args is: [x_scale, x_zp, o_inv_scale, o_zp, 7203 fp32_output, binary_attr, aplha, unary_attr, unary_scalars, unary_algorithm] 7204 else 7205 - inputs = [x, w, weight_scale, weight_zp, x2] 7206 - const_args is: [bias, x_scale, x_zp, o_inv_scale, o_zp, 7207 fp32_output, binary_attr, aplha, unary_attr, unary_scalars, unary_algorithm] 7208 """ 7209 self.has_bias = has_bias 7210 self.x_scale_zp_are_tensors = x_scale_zp_are_tensors 7211 super().__init__( 7212 layout, 7213 inputs, 7214 constant_args, 7215 None, 7216 python_kernel_name=( 7217 "torch.ops.onednn.qlinear_pointwise.binary_tensor" 7218 if x_scale_zp_are_tensors 7219 else "torch.ops.onednn.qlinear_pointwise.binary" 7220 ), 7221 cpp_kernel_name="onednn::qlinear_pointwise", 7222 ) 7223 self.cpp_kernel_overload_name = ( 7224 "binary_tensor" if x_scale_zp_are_tensors else "binary" 7225 ) 7226 self.cpp_kernel_key = "qlinear_pointwise_binary" 7227 x_scale_type_str, x_zp_type_str = ( 7228 ("at::Tensor", "at::Tensor") 7229 if x_scale_zp_are_tensors 7230 else ("double", "int64_t") 7231 ) 7232 self.cpp_op_schema = f""" 7233 at::Tensor( 7234 at::Tensor act, 7235 {x_scale_type_str} act_scale, 7236 {x_zp_type_str} act_zero_point, 7237 at::Tensor weight, 7238 at::Tensor weight_scales, 7239 at::Tensor weight_zero_points, 7240 c10::optional<at::Tensor> bias, 7241 double inv_output_scale, 7242 int64_t output_zero_point, 7243 c10::optional<c10::ScalarType> output_dtype, 7244 c10::optional<at::Tensor> other, 7245 double other_scale, 7246 int64_t other_zero_point, 7247 c10::string_view binary_post_op, 7248 double binary_alpha, 7249 c10::string_view unary_post_op, 7250 torch::List<c10::optional<at::Scalar>> unary_post_op_args, 7251 c10::string_view unary_post_op_algorithm)""" 7252 7253 def codegen(self, wrapper): 7254 # Parser the inputs and constant 7255 args = [x.codegen_reference() for x in self.inputs] 7256 const_args = [] 7257 
const_args.extend(self.codegen_const_args()) 7258 7259 x = args[0] 7260 packed_weight = args[1] 7261 bias = args[2] if self.has_bias else const_args[0] 7262 w_scale, w_zp, other = args[-3], args[-2], args[-1] 7263 if self.x_scale_zp_are_tensors: 7264 assert len(args) >= 5 7265 x_scale, x_zp = args[-5], args[-4] 7266 ( 7267 o_inv_scale, 7268 o_zp, 7269 output_dtype, 7270 other_scale, 7271 other_zp, 7272 binary_attr, 7273 alpha, 7274 unary_attr, 7275 unary_scalars, 7276 unary_algorithm, 7277 ) = const_args[-10:] 7278 else: 7279 assert len(const_args) >= 8 7280 ( 7281 x_scale, 7282 x_zp, 7283 o_inv_scale, 7284 o_zp, 7285 output_dtype, 7286 other_scale, 7287 other_zp, 7288 binary_attr, 7289 alpha, 7290 unary_attr, 7291 unary_scalars, 7292 unary_algorithm, 7293 ) = const_args[-12:] 7294 7295 codegen_args = ( 7296 x, 7297 x_scale, 7298 x_zp, 7299 packed_weight, 7300 w_scale, 7301 w_zp, 7302 bias, 7303 o_inv_scale, 7304 o_zp, 7305 output_dtype, 7306 other, 7307 other_scale, 7308 other_zp, 7309 binary_attr, 7310 alpha, 7311 unary_attr, 7312 unary_scalars, 7313 unary_algorithm, 7314 ) 7315 wrapper.generate_extern_kernel_alloc_and_find_schema_if_needed( 7316 self.get_name(), 7317 self.python_kernel_name, 7318 self.cpp_kernel_name, 7319 codegen_args, 7320 self.cpp_op_schema, 7321 self.cpp_kernel_key, 7322 self.cpp_kernel_overload_name, 7323 ) 7324 if isinstance(self.layout, Layout): 7325 self.codegen_size_asserts(wrapper) 7326 7327 @classmethod 7328 def create( 7329 cls, 7330 x: "TensorBox", 7331 x_scale: float, 7332 x_zp: int, 7333 weight: "TensorBox", # packed_weight 7334 w_scale: "TensorBox", 7335 w_zp: "TensorBox", 7336 bias: "TensorBox", 7337 o_inv_scale: float, 7338 output_zero_point: int, 7339 output_dtype, 7340 other: "TensorBox", 7341 other_scale, 7342 other_zp, 7343 binary_attr, 7344 alpha, 7345 unary_attr, 7346 unary_scalars, 7347 unary_algorithm, 7348 ): 7349 ( 7350 inputs, 7351 constant_args, 7352 kernel_layout, 7353 req_stride_order, 7354 ) = _prepare_linear_fusion_create( 7355 cls, 7356 x, 7357 weight, 7358 bias, 7359 ) 7360 7361 if isinstance(x_scale, TensorBox) and isinstance(x_zp, TensorBox): 7362 x_scale.realize() 7363 x_zp.realize() 7364 inputs = inputs + [x_scale, x_zp] 7365 x_scale_zp_are_tensors = True 7366 else: 7367 assert isinstance(x_scale, float) and isinstance(x_zp, int) 7368 constant_args = constant_args + [x_scale, x_zp] 7369 x_scale_zp_are_tensors = False 7370 w_scale.realize() 7371 w_zp.realize() 7372 inputs = inputs + [w_scale, w_zp] 7373 if binary_attr == "sum": 7374 other = cls.require_stride_order(other, req_stride_order) 7375 inputs.append(other) 7376 constant_args = constant_args + [ 7377 o_inv_scale, 7378 output_zero_point, 7379 output_dtype, 7380 other_scale, 7381 other_zp, 7382 binary_attr, 7383 alpha, 7384 unary_attr, 7385 may_convert_to_optional(unary_scalars), 7386 unary_algorithm, 7387 ] 7388 7389 if binary_attr == "sum": 7390 packed = QLinearPointwiseBinaryPT2E( 7391 layout=NoneLayout(other.get_device()), 7392 inputs=inputs, 7393 constant_args=constant_args, 7394 has_bias=(bias is not None), 7395 x_scale_zp_are_tensors=x_scale_zp_are_tensors, 7396 ) 7397 mark_node_as_mutating(packed, other) 7398 # Return other since it has been inplace changed. 7399 return packed.inputs[-1] 7400 7401 if output_dtype is not None: 7402 assert output_dtype in [torch.float32, torch.bfloat16] 7403 # in _prepare_linear_fusion_create, we use x.dtype (uint8) to create kernel_layout 7404 # if we set fp32_output, the output buf should be dtype float32 instead of uint8. 
7405 kernel_layout.dtype = output_dtype 7406 7407 return QLinearPointwiseBinaryPT2E( 7408 layout=kernel_layout, 7409 inputs=inputs, 7410 constant_args=constant_args, 7411 has_bias=(bias is not None), 7412 x_scale_zp_are_tensors=x_scale_zp_are_tensors, 7413 ) 7414 7415 7416@dataclasses.dataclass 7417class MutableBox(IRNode): 7418 """ 7419 TensorBox / StorageBox allow in-place mutation of Tensors 7420 """ 7421 7422 data: IRNode 7423 7424 def __getattr__(self, name): 7425 fn = getattr(self.data, name) 7426 if callable(fn): 7427 return fn 7428 raise AttributeError(f"{type(self.data).__name__}.{name} not callable") 7429 7430 def realize(self): 7431 return self.data.realize() 7432 7433 def get_unbacked_symbol_uses(self) -> Set[sympy.Symbol]: 7434 return self.data.get_unbacked_symbol_uses() 7435 7436 def codegen_reference(self, writer=None): 7437 return self.data.codegen_reference(writer) 7438 7439 @property 7440 def layout(self): 7441 return self.data.get_layout() 7442 7443 def get_layout(self): 7444 return self.layout 7445 7446 def get_size(self): 7447 return self.data.get_size() 7448 7449 @property 7450 def dtype(self): 7451 return self.data.dtype 7452 7453 def __str__(self): 7454 if isinstance(self.data, MutableBox): 7455 line0 = f"{type(self).__name__}({type(self.data).__name__}(" 7456 endl = "))" 7457 inner = self.data.data 7458 else: 7459 line0 = f"{type(self).__name__}(" 7460 inner = self.data 7461 endl = ")" 7462 7463 lines = [ 7464 line0, 7465 indent(str(inner)), 7466 endl, 7467 ] 7468 return "\n".join(lines) 7469 7470 __repr__ = __str__ 7471 7472 7473class TensorBox(MutableBox): 7474 @staticmethod 7475 def create(data): 7476 return TensorBox(StorageBox(data)) 7477 7478 7479class StorageBox(MutableBox): 7480 def is_input_buffer(self): 7481 if isinstance(self.data, (InputBuffer, ReinterpretView)): 7482 return self.data.get_name() in V.graph.graph_inputs 7483 return False 7484 7485 def is_module_buffer(self): 7486 return ( 7487 isinstance(self.data, (ConstantBuffer)) 7488 and self.data.get_name() in V.graph.constants 7489 ) 7490 7491 def realize(self): 7492 if isinstance( 7493 self.data, 7494 ( 7495 ComputedBuffer, 7496 InputsKernel, 7497 InputBuffer, 7498 ReinterpretView, 7499 TemplateBuffer, 7500 ), 7501 ): 7502 return self.data.get_name() 7503 assert isinstance(self.data, (Pointwise, Reduction, Scan)), type(self.data) 7504 origin_node = self.data.get_origin_node() 7505 traceback = self.data.get_traceback() 7506 self.data = ComputedBuffer( 7507 name=None, 7508 layout=FlexibleLayout( 7509 device=self.data.get_device(), 7510 dtype=self.data.get_dtype(), 7511 size=self.data.get_size(), 7512 ), 7513 data=self.data, 7514 ) 7515 self.data.name = V.graph.register_buffer(self.data) 7516 self.data.origins = self.origins 7517 self.data.origin_node = origin_node 7518 self.data.traceback = traceback 7519 return self.data.name 7520 7521 def realize_hint(self): 7522 """ 7523 Called on buffers we expect to be forced to realize later. 7524 """ 7525 if ( 7526 isinstance(self.data, (Pointwise, Reduction)) 7527 and self.num_reads() > 1 7528 and self.is_pointwise_non_scalar_tensor_num_reads_larger_than_one() 7529 ): 7530 self.realize() 7531 7532 def has_exceeded_max_reads(self): 7533 return isinstance(self.data, Pointwise) and ( 7534 self.num_reads() > config.realize_acc_reads_threshold 7535 or self.has_large_inner_fn() 7536 ) 7537 7538 def mark_reuse(self, users): 7539 """ 7540 A heuristic to decide if we should realize a tensor 7541 that is used multiple times. 
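        For example, a Pointwise or Reduction node with more than one user is
        realized into a ComputedBuffer (instead of being inlined at each use)
        once it reads more than config.realize_reads_threshold buffers, has a
        large inner fn, or is a reused heavy op (e.g. exp) on CPU.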
7542 """ 7543 7544 def should_realize_on_cpu(loops: Union[Pointwise, Reduction]): 7545 """ 7546 The heuristic for realizing reused result of heavy ops on cpu 7547 """ 7548 heavy_ops = ["exp"] # a list of heavy ops 7549 fn_str = loops.inner_fn_str() 7550 return any((op + "(") in fn_str for op in heavy_ops) 7551 7552 if ( 7553 users > 1 7554 and isinstance(self.data, (Pointwise, Reduction)) 7555 and ( 7556 self.num_reads() > config.realize_reads_threshold 7557 or self.has_large_inner_fn() 7558 or (is_cpu(self.data) and should_realize_on_cpu(self.data)) 7559 ) 7560 ): 7561 self.realize() 7562 7563 @cache_on_self 7564 def num_reads(self): 7565 data = self.data 7566 if isinstance(data, (InputsKernel, InputBuffer, ReinterpretView)): 7567 return 1 7568 if isinstance(data, ComputedBuffer): 7569 read_writes = data.get_read_writes() 7570 else: 7571 assert isinstance(data, (Pointwise, Reduction)), type(data) 7572 read_writes = ComputedBuffer( 7573 name=None, 7574 layout=FlexibleLayout( 7575 device=data.get_device(), 7576 dtype=data.get_dtype(), 7577 size=data.get_size(), 7578 ), 7579 data=data, 7580 ).get_read_writes() 7581 return len(read_writes.reads) 7582 7583 @cache_on_self 7584 def is_pointwise_non_scalar_tensor_num_reads_larger_than_one(self): 7585 # Skip the check for non Pointwise instances 7586 return ( 7587 (sum(read.index != 0 for read in self.data.get_reads()) > 1) 7588 if isinstance(self.data, Pointwise) 7589 and all( 7590 not isinstance(read, dependencies.StarDep) 7591 for read in self.data.get_reads() 7592 ) 7593 else True 7594 ) 7595 7596 7597@dataclasses.dataclass 7598class Subgraph(IRNode): 7599 name: str 7600 graph_module: torch.fx.GraphModule 7601 graph: Optional["GraphLowering"] = None 7602 7603 7604def _has_aliased_buffers(buffers): 7605 buffers = [ 7606 buffer.unwrap_view() if isinstance(buffer, ReinterpretView) else buffer 7607 for buffer in buffers 7608 ] 7609 # assuming the same buffer is represented by the same IRNode object 7610 return len({id(buffer) for buffer in buffers}) < len(buffers) 7611 7612 7613@dataclasses.dataclass 7614class Conditional(ExternKernel): 7615 predicate: Optional[IRNode] = None 7616 operands: Optional[List[TensorBox]] = None 7617 true_subgraph: Optional[Subgraph] = None 7618 false_subgraph: Optional[Subgraph] = None 7619 outputs: Optional[List[MultiOutput]] = None 7620 7621 def __init__( 7622 self, 7623 predicate: IRNode, 7624 operands: List[TensorBox], 7625 true_subgraph: Subgraph, 7626 false_subgraph: Subgraph, 7627 layout: MultiOutputLayout, 7628 ): 7629 self.predicate = predicate 7630 self.operands = operands 7631 self.true_subgraph = true_subgraph 7632 self.false_subgraph = false_subgraph 7633 7634 inputs = [] 7635 if not isinstance(predicate, ShapeAsConstantBuffer): 7636 inputs.append(predicate) 7637 inputs.extend(operands) 7638 7639 super().__init__( 7640 name=None, 7641 layout=layout, # type: ignore[arg-type] 7642 inputs=inputs, # type: ignore[list-item] 7643 ) 7644 7645 self.name = V.graph.register_buffer(self) 7646 7647 @classmethod 7648 def create( 7649 cls, 7650 predicate: TensorBox, 7651 true_fn: Subgraph, 7652 false_fn: Subgraph, 7653 operands: List[TensorBox], 7654 ): 7655 predicate = cls.realize_input(predicate) 7656 operands = [cls.realize_input(x) for x in operands] 7657 7658 fx_operands = V.graph.current_node.args[-1] 7659 fake_operands = [x.meta["val"] for x in fx_operands] # type: ignore[union-attr] 7660 7661 for subgraph in (true_fn, false_fn): 7662 if subgraph.graph is None: 7663 # create and lower subgraphs 7664 
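                # Both branches are traced against the same fake operands (e.g.
                # torch.cond(pred, true_fn, false_fn, (x,)) lowers each Subgraph
                # with the fake value of x), so the structural checks below
                # compare outputs produced from identical input metadata.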
                subgraph.graph = V.graph.make_subgraph(
                    gm=subgraph.graph_module,
                    example_inputs=fake_operands,
                    subgraph_name=subgraph.name,
                )
                with V.set_graph_handler(subgraph.graph):
                    subgraph.graph.run(*fake_operands)

        true_outputs = true_fn.graph.graph_outputs  # type: ignore[union-attr]
        false_outputs = false_fn.graph.graph_outputs  # type: ignore[union-attr]

        for name, outputs in (("true_fn", true_outputs), ("false_fn", false_outputs)):
            if _has_aliased_buffers(outputs):
                raise AssertionError(
                    "Output aliasing is currently not supported in compiled torch.cond. "
                    f"The outputs of the {name} subgraph of torch.cond are aliased: {outputs}"
                )

        # make sure true and false outputs are structurally equivalent
        assert len(true_outputs) == len(false_outputs), (true_outputs, false_outputs)
        for i, (to, fo) in enumerate(zip(true_outputs, false_outputs)):
            assert to.get_size() == fo.get_size(), (i, to, fo)
            assert to.get_stride() == fo.get_stride(), (i, to, fo)
            assert to.get_device() == fo.get_device(), (i, to, fo)
            assert to.get_dtype() == fo.get_dtype(), (i, to, fo)
            assert to.get_layout().offset == fo.get_layout().offset, (i, to, fo)

        if not isinstance(predicate, ShapeAsConstantBuffer):
            # use the predicate's device for consistent codegen
            device = predicate.get_device()
        else:
            # predicate is not a Tensor: use first operand's device
            assert (
                len(operands) > 0
            ), "When predicate is not a Tensor, there must be at least one operand in torch.cond."
            device = operands[0].get_device()

        conditional = Conditional(
            predicate=predicate,
            operands=operands,
            true_subgraph=true_fn,
            false_subgraph=false_fn,
            layout=MultiOutputLayout(device),
        )

        outputs = [
            MultiOutput(
                FixedLayout(
                    device=output.get_device(),
                    dtype=output.get_dtype(),
                    size=output.get_size(),
                    stride=output.get_stride(),
                    offset=output.get_layout().offset,
                ),
                conditional,
                [(list, i)],
            )
            # as the true and false outputs are equivalent,
            # we can use either of them here as a "template"
            for i, output in enumerate(true_outputs)
        ]

        conditional.outputs = outputs
        return outputs

    def codegen(self, wrapper):
        wrapper.codegen_conditional(self)


@dataclasses.dataclass
class WhileLoop(ExternKernel):
    carried_inputs: Optional[List[TensorBox]] = None
    additional_inputs: Optional[List[TensorBox]] = None
    cond_subgraph: Optional[Subgraph] = None
    body_subgraph: Optional[Subgraph] = None
    outputs: Optional[List[MultiOutput]] = None

    def __init__(
        self,
        carried_inputs: List[TensorBox],
        additional_inputs: List[TensorBox],
        cond_subgraph: Subgraph,
        body_subgraph: Subgraph,
        layout: MultiOutputLayout,
    ):
        self.carried_inputs = carried_inputs
        self.additional_inputs = additional_inputs
        self.cond_subgraph = cond_subgraph
        self.body_subgraph = body_subgraph

        super().__init__(
            name=None,
            layout=layout,  # type: ignore[arg-type]
            inputs=carried_inputs + additional_inputs,  # type: ignore[list-item]
        )

        self.name = V.graph.register_buffer(self)

    @classmethod
    def create(
        cls,
        cond_fn: Subgraph,
        body_fn: Subgraph,
        carried_inputs: List[TensorBox],
        additional_inputs: List[TensorBox],
    ):
        carried_inputs = 
        additional_inputs = [cls.realize_input(x) for x in additional_inputs]
        all_inputs = carried_inputs + additional_inputs

        fx_all_inputs = V.graph.current_node.args[-2] + V.graph.current_node.args[-1]  # type: ignore[operator]
        fake_all_inputs = [x.meta["val"] for x in fx_all_inputs]  # type: ignore[union-attr]

        for subgraph in (cond_fn, body_fn):
            if subgraph.graph is None:
                # create and lower subgraphs
                subgraph.graph = V.graph.make_subgraph(
                    gm=subgraph.graph_module,
                    example_inputs=fx_all_inputs,  # type: ignore[arg-type]
                    subgraph_name=subgraph.name,
                )
                with V.set_graph_handler(subgraph.graph):
                    subgraph.graph.run(*fake_all_inputs)

        cond_outputs = cond_fn.graph.graph_outputs  # type: ignore[union-attr]
        body_outputs = body_fn.graph.graph_outputs  # type: ignore[union-attr]

        if _has_aliased_buffers(body_outputs):
            raise AssertionError(
                "Output aliasing is currently not supported in compiled torch.while_loop. "
                f"The outputs of the body_fn subgraph of torch.while_loop are aliased: {body_outputs}"
            )

        # make sure cond_fn returns a boolean scalar Tensor
        assert len(cond_outputs) == 1, cond_outputs
        assert cond_outputs[0].get_dtype() == torch.bool, cond_outputs
        assert len(cond_outputs[0].get_size()) == 0, cond_outputs

        assert (
            len(all_inputs) > 0
        ), "torch.while_loop is assumed to have at least one operand."

        device = all_inputs[0].get_device()

        # make sure carried_inputs and body outputs are structurally equivalent
        assert len(carried_inputs) == len(body_outputs), (carried_inputs, body_outputs)
        for i, (op, bo) in enumerate(zip(carried_inputs, body_outputs)):
            assert op.get_size() == bo.get_size(), (i, op, bo)
            assert op.get_stride() == bo.get_stride(), (i, op, bo)
            # assume all carried_inputs and outputs are on the same device
            # as the MultiOutputLayout below requires single device
            assert op.get_device() == bo.get_device() == device, (i, op, bo, device)
            assert op.get_dtype() == bo.get_dtype(), (i, op, bo)
            assert op.get_layout().offset == bo.get_layout().offset, (i, op, bo)

        while_loop = WhileLoop(
            carried_inputs=carried_inputs,
            additional_inputs=additional_inputs,
            cond_subgraph=cond_fn,
            body_subgraph=body_fn,
            # asserted above that there is at least one operand
            layout=MultiOutputLayout(device),
        )

        outputs = [
            MultiOutput(
                FixedLayout(
                    device=output.get_device(),
                    dtype=output.get_dtype(),
                    size=output.get_size(),
                    stride=output.get_stride(),
                    offset=output.get_layout().offset,
                ),
                while_loop,
                [(list, i)],
            )
            for i, output in enumerate(body_outputs)
        ]

        for inp, out in zip(carried_inputs, outputs):
            if inp.get_name() in V.graph.graph_inputs:
                # if a carried input of the while_loop is a graph input,
                # it can be returned as is when the number of iterations
                # is zero. due to this, we can't (generally) reuse the
                # output buffers corresponding to the graph inputs, as
                # the inputs may end up being mutated.
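                # adding the output to never_reuse_buffers keeps buffer
                # planning from recycling its allocation for later kernels.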
                V.graph.never_reuse_buffers.add(out.get_name())

        while_loop.outputs = outputs
        return outputs

    def codegen(self, wrapper):
        wrapper.codegen_while_loop(self)


class EffectfulKernel(FallbackKernel):
    def __init__(
        self,
        layout,
        kernel,
        tensor_args,
        nontensor_args,
        unflatten_args,
        kwargs=None,
        *,
        unbacked_bindings=None,
    ):
        super().__init__(
            layout,
            kernel,
            tensor_args,
            nontensor_args,
            unflatten_args,
            kwargs=None,
            unbacked_bindings=unbacked_bindings,
        )

        from torch._higher_order_ops.effects import get_effect_key

        effect_type = get_effect_key(kernel, (*nontensor_args, *tensor_args), kwargs)
        assert effect_type is not None
        self.effect_type = effect_type
        self.prev_effect_buffer = V.graph.effectful_ops.get(effect_type, None)
        V.graph.effectful_ops[effect_type] = self

    def get_read_writes(self):
        read_writes = super().get_read_writes()

        if self.prev_effect_buffer is not None:
            read_writes.reads.add(
                dependencies.StarDep(self.prev_effect_buffer.get_name())
            )

        return read_writes

    def has_side_effects(self):
        return True


@dataclasses.dataclass
class TorchBindObject(IRNode):
    name: str
    value: torch._C.ScriptObject

    def get_name(self):
        return self.name

    def get_device(self):
        return None  # is there a device??

    def codegen_reference(self, writer=None):
        return self.name


class InterpreterShim(torch.fx.Interpreter):
    @staticmethod
    @functools.lru_cache(None)
    def _dummy_gm():
        return torch.fx.symbolic_trace(identity)

    def __init__(self, graph, submodules):
        # call super() with a placeholder to avoid constructing a
        # GraphModule which is very expensive (it does codegen).
        super().__init__(self._dummy_gm(), garbage_collect_values=False)
        self.module = self  # type: ignore[assignment]
        self.graph = graph
        self.submodules = submodules
        self.extra_traceback = False
        self.fetch_attr = submodules.__getitem__
        self.current_node = None

    def run_node(self, n: torch.fx.Node) -> Any:
        self.current_node = n
        return super().run_node(n)

    def run(self, *args, **kwargs):
        with V.set_interpreter_handler(self):
            return super().run(*args, **kwargs)


class LoopBody:
    """
    Captures the body of a Loops subclass into an FX graph. Persists any
    indexing simplifications and makes it easier to analyze loop bodies.
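
    As a rough sketch (variable names are illustrative, not exact): capturing
    a pointwise op over a contiguous [8, 16] tensor records var_ranges such as
    {i0: 8, i1: 16} and an indexing expression like index0 = 16*i0 + i1, and
    the traced FX graph fetches that expression through the "get_index"
    submodule when the body is replayed.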
    """

    def __init__(self, fn, args, var_ranges):
        super().__init__()
        self.var_ranges = var_ranges
        self.indexing_exprs = {}
        self.indexing_exprs_name = {}
        self.reads = []
        self.writes = []
        self.reads_name2expr = {}
        self.writes_name2expr = {}
        self.other = []
        self.submodules = {"get_index": self.get_index}
        self.subblocks = {}
        self.indirect_vars = []
        self.root_block = LoopBodyBlock(self, fn, args)
        self.indexing = None

    @cache_on_self
    def get_nodes(self):
        all_graphs = itertools.chain(
            (self.root_block.graph,),
            (block.graph for block in self.subblocks.values()),
        )
        return [node for graph in all_graphs for node in graph.nodes]

    @cache_on_self
    def bounds(self):
        # Doing a local import to avoid dumping all the code here
        from .bounds import BoundVars

        return BoundVars(self)

    def debug_str(self):
        lines = [f"var_ranges = {dict(self.var_ranges)}"]
        lines.extend([f"{name} = {val}" for name, val in self.indexing_exprs.items()])
        lines.extend(
            [
                block.debug_str(name)
                for name, block in itertools.chain(
                    [("body", self.root_block)], self.subblocks.items()
                )
            ]
        )
        return "\n".join(lines)

    def add_index_expr(self, expr: sympy.Expr, category, buf_name):
        getattr(self, category).append(expr)
        if buf_name is not None:
            getattr(self, f"{category}_name2expr")[buf_name] = expr
        if expr not in self.indexing_exprs_name:
            name = f"index{len(self.indexing_exprs)}"
            self.indexing_exprs_name[expr] = name
            self.indexing_exprs[name] = expr
        return self.indexing_exprs_name[expr]

    def add_submodule(self, block, prefix):
        """Not actually for nn.Modules, but subblocks in generated code are mapped to FX call_module opcodes"""
        if prefix[-1].isnumeric() and prefix not in self.submodules:
            name = prefix
        else:
            name = f"{prefix}{len(self.submodules)}"
        self.submodules[name] = block
        return name

    def add_indirect(self, size):
        var = sympy_index_symbol_with_prefix(SymT.INDIRECT, len(self.indirect_vars))
        self.indirect_vars.append(var)
        return var

    def replace_indirect(self, old, new):
        """Swap in a variable used in indirect indexing"""
        if str(old) == str(new):
            return
        assert self.indexing is not None
        self.indexing = {k: sympy_subs(v, {old: new}) for k, v in self.indexing.items()}

    def get_index(self, name):
        assert self.indexing is not None
        return self.indexing[name]

    def __call__(self, *indices):
        index = list(itertools.chain.from_iterable(indices))
        assert len(index) == len(self.var_ranges), (index, self.var_ranges)
        assert all(v not in self.var_ranges for v in index)
        replacements = dict(zip(self.var_ranges.keys(), index))
        self.indexing = {
            name: sympy_subs(expr, replacements)
            for name, expr in self.indexing_exprs.items()
        }
        result = self.root_block()
        self.indexing = None
        return result


class LoopBodyBlock:
    """
    Captures the body of a Loops subclass into an FX graph.
    In normal cases there will be a 1:1 mapping between LoopBody and
    LoopBodyBlock, however in the case of ops.masked() the masked out
    operations will manifest as an extra LoopBodyBlock.
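
    The masked sub-body is registered through LoopBody.add_submodule and shows
    up in the parent graph as a call_module node (named "masked_subblock<n>"),
    which InterpreterShim dispatches back to the captured block when the graph
    is replayed.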
    """

    def __init__(self, body: LoopBody, fn: Callable[..., Any], args: List[Any]):
        self.body = body

        def add_index(expr, category, buf_name=None):
            return tracer.create_proxy(
                "call_module",
                "get_index",
                (self.body.add_index_expr(expr, category, buf_name),),
                {},
            )

        class CaptureIndexing(V.WrapperHandler):  # type: ignore[name-defined]
            # NB: `self` here is the enclosing LoopBodyBlock being constructed
            # (captured from __init__), so this sets an attribute on it rather
            # than on CaptureIndexing.
            self.name = "CaptureIndexing"

            def load(self, name: str, index: sympy.Expr):
                index = add_index(index, "reads", name)
                return self._inner.load(name, index)

            def store(self, name, index, value, mode=None):
                index = add_index(index, "writes", name)
                return self._inner.store(name, index, value, mode)

            def store_reduction(self, name, index, value):
                index = add_index(index, "writes", name)
                return self._inner.store_reduction(name, index, value)

            def reduction(self, dtype, src_dtype, reduction_type, value):
                result = self._inner.reduction(dtype, src_dtype, reduction_type, value)
                if "welford" in reduction_type:
                    return tuple(result[i] for i in range(3))
                return result

            def index_expr(self, index, dtype):
                if isinstance(index, (int, sympy.Integer)):
                    return self._inner.constant(int(index), dtype)
                index = add_index(index, "other")
                return self._inner.index_expr(index, dtype)

            def check_bounds(self, index, size, lower, upper):
                index = add_index(index, "other")
                size = add_index(size, "other")
                return self._inner.check_bounds(index, size, lower, upper)

            def bucketize(
                self,
                values,
                offsets_name: str,
                offsets_size: sympy.Expr,
                indexing_dtype: torch.dtype,
                right: bool,
            ):
                offsets_size = add_index(offsets_size, "other")
                return self._inner.bucketize(
                    values, offsets_name, offsets_size, indexing_dtype, right
                )

            @staticmethod
            def masked(mask_proxy, masked_body: Callable[..., Any], other_proxy):
                """
                Recursively capture the masked out body in another LoopBodyBlock
                """

                subblock: LoopBodyBlock

                def shim(mask, other):
                    return V.ops.masked(mask, subblock, other)

                name = self.body.add_submodule(shim, "masked_subblock")
                subblock = LoopBodyBlock(self.body, masked_body, [])
                self.body.subblocks[name] = subblock
                return tracer.create_proxy(
                    "call_module", name, (mask_proxy, other_proxy), {}
                )

            @staticmethod
            def scan(
                dtype_proxy,
                combine_fn: Callable[
                    [Tuple[Any, ...], Tuple[Any, ...]], Tuple[Any, ...]
                ],
                value_proxy,
            ):
                def shim(dtypes, values):
                    return V.ops.scan(dtypes, combine_fn, values)

                name = self.body.add_submodule(shim, "scan")
                result = tracer.create_proxy(
                    "call_module",
                    name,
                    (dtype_proxy, value_proxy),
                    {},
                )
                # Proxies are iterable, but some methods expect tuples/lists
                return tuple(result[i] for i in range(len(value_proxy)))

            def frexp(self, value_proxy):
                result = self._inner.frexp(value_proxy)
                # Proxies are iterable, but some methods expect tuples/lists
                return (result[0], result[1])

            @staticmethod
            def indirect_indexing(index_proxy, size, check=True):
                """
                Flow data from tensors into indexing formulas.
                Introduce a call_module to update the indexing.
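
                At interpretation time the generated "set_indirect*" submodule
                receives the realized index value and substitutes it into the
                captured indexing expressions via LoopBody.replace_indirect,
                so later get_index lookups see the resolved formula.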
                """

                var = self.body.add_indirect(size)

                def set_indirect(new_var):
                    self.body.replace_indirect(
                        var, V.ops.indirect_indexing(new_var, size, check)
                    )

                tracer.create_proxy(
                    "call_module",
                    self.body.add_submodule(set_indirect, f"set_{var}"),
                    (index_proxy,),
                    {},
                )
                return var

            @staticmethod
            def output(result):
                tracer.create_proxy("output", "output", (result,), {})

        tracer = torch.fx.Tracer()
        tracer.graph = torch.fx.Graph(tracer_cls=tracer.__class__)
        proxy_ops = tracer.create_proxy("placeholder", "ops", (), {})

        from .index_propagation import IndexPropagation
        from .sizevars import SimplifyIndexing

        handler: Any = SimplifyIndexing(
            CaptureIndexing(proxy_ops), self.body.var_ranges
        )
        if config.constant_and_index_propagation:
            handler = IndexPropagation(handler, self.body.var_ranges)

        with V.set_ops_handler(handler):
            # This indirection is just a cute way to get IndexPropagation to
            # unwrap the return value.
            ops.output(fn(*args))
        self.graph = tracer.graph

    def __call__(self):
        graph = self.graph
        submodules = self.body.submodules

        return InterpreterShim(graph, submodules).run(V.get_ops_handler())

    def debug_str(self, name="block"):
        code = torch.fx.GraphModule(self.body.submodules, self.graph).code
        return re.sub(
            # strip `; del var0` suffixes to make output prettier
            r";[^\n]*",
            "",
            code.strip().replace("def forward(", f"def {name}("),
        )


class _CollectiveKernel(FallbackKernel):
    def should_allocate(self):
        return False

    def has_side_effects(self):
        return True

    # This is identical to FallbackKernel.set_cpp_kernel(), minus the
    # part that checks against input aliasing and mutation.
    def set_cpp_kernel(self, kernel):
        from .codegen.wrapper import get_cpp_op_schema

        self.cpp_kernel_name = kernel._schema.name
        self.cpp_kernel_overload_name = kernel._schema.overload_name
        self.cpp_kernel_key = f"{self.cpp_kernel_name.replace('::', '_')}_{self.cpp_kernel_overload_name}"  # type: ignore[union-attr]

        self.cpp_op_schema = get_cpp_op_schema(kernel)
        self.ordered_kwargs_for_cpp_kernel = [
            x.name for x in kernel._schema.arguments if x.kwarg_only
        ]

    # NOTE: [In-Place Collective Safety]
    # Between the initiation and completion of an in-place collective, the
    # input buffers are subject to both volatile reads and volatile writes.
    # They must not be read, written to or reused by another kernel. To enforce
    # these constraints, we model collective -> wait_tensor as a two-step
    # mutation of the input buffers.
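    #
    # Rough sketch (schematic, not generated code): for an in-place all_reduce_
    # on buf0, both the collective kernel created here and the subsequent
    # wait_tensor (see _WaitKernel.create_wait) are registered as mutations of
    # buf0 via mark_node_as_mutating, so every other reader/writer of buf0 is
    # ordered either before the collective or after the wait.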
    @classmethod
    def create_inplace(
        cls, kernel, inputs: Union[TensorBox, List[TensorBox]], *args, **kwargs
    ) -> None:
        cpp_kernel_name = kernel._name
        python_kernel_name = cpp_kernel_name.replace("::", ".")
        with V.graph.fake_mode:
            (
                example_output,
                tensor_args,
                non_tensor_args,
                unflatten_args,
                unbacked_bindings,
            ) = cls.process_kernel(kernel, inputs, *args, **kwargs)
        assert not unbacked_bindings, f"{kernel} {unbacked_bindings}"
        for tensor_arg in tensor_args:
            tensor_arg.realize()

        packed = cls(
            NoneLayout(tensor_args[0].get_device()),
            kernel,
            tensor_args,
            non_tensor_args,
            unflatten_args,
        )
        packed.cpp_kernel_name = cpp_kernel_name
        packed.python_kernel_name = python_kernel_name

        mark_node_as_mutating(packed, *pytree.tree_leaves(inputs))

    # NOTE: [Out-of-Place Collective Safety]
    # Between the initiation and completion of an out-of-place collective:
    #
    # Input buffers:
    # - Are subject to volatile reads
    # - Can be read by another kernel
    # - Must not be written to or reused by another kernel
    #
    # Output buffers:
    # - Are subject to volatile writes
    # - Must not be read, written to or reused by another kernel
    #
    # To ensure the safety of input buffers without sacrificing read
    # availability, we add input buffers as read deps of wait_tensor kernels.
    #
    # To ensure the safety of output buffers, we model wait_tensor as a
    # mutation to the output buffer. Note that we also assume the user program
    # is correct and the output buffer is not consumed by kernels other than
    # wait_tensor.
    #
    # TODO(yifu): add a pre-grad pass to validate the correctness of collective
    # usage in the user program.
    @classmethod
    def create_out_of_place(
        cls, kernel, inputs: Union[TensorBox, List[TensorBox]], *args, **kwargs
    ):
        cpp_kernel_name = kernel._name
        python_kernel_name = cpp_kernel_name.replace("::", ".")
        with V.graph.fake_mode:
            (
                example_output,
                tensor_args,
                non_tensor_args,
                unflatten_args,
                unbacked_bindings,
            ) = cls.process_kernel(kernel, inputs, *args, **kwargs)
        assert not unbacked_bindings, f"{kernel}, {unbacked_bindings}"
        for tensor_arg in tensor_args:
            tensor_arg.realize()

        if isinstance(example_output, list):
            device = cls.find_device(tensor_args, example_output)
            packed = cls(
                MultiOutputLayout(device),
                kernel,
                tensor_args,
                non_tensor_args,
                unflatten_args,
            )
            packed.cpp_kernel_name = cpp_kernel_name
            packed.python_kernel_name = python_kernel_name
            packed.outputs = [
                MultiOutput(
                    cls.tensor_to_layout(tensor),
                    packed,
                    [(list, i)],
                )
                for i, tensor in enumerate(example_output)
            ]
            return packed.outputs
        else:
            packed = cls(
                cls.tensor_to_layout(example_output),
                kernel,
                tensor_args,
                non_tensor_args,
                unflatten_args,
            )
            packed.cpp_kernel_name = cpp_kernel_name
            packed.python_kernel_name = python_kernel_name
            packed.outputs = [packed]
            return packed


class _WaitKernel(_CollectiveKernel):
    def get_volatile_reads(self):
        inp = self.inputs[0]
        if isinstance(inp, _CollectiveKernel):
            # Out-of-place single-output
            return [inp.inputs[0]]
        elif isinstance(inp, MultiOutput):
            # This can be two things:
            # 1. Out-of-place multi-output coll
            # 2. In-place coll with inputs coming from another MultiOutput
            coll = inp.inputs[0]
            # Case 1
            if isinstance(coll, _CollectiveKernel):
                _, idx = inp.indices[0]
                return [coll.inputs[idx]]
            # Case 2
            return []
        else:
            # In-place requires no additional deps handling for volatile
            # reads since the inputs are mutated.
            return []

    @classmethod
    def create_wait(cls, kernel, inp: TensorBox) -> None:
        with V.graph.fake_mode:
            (
                example_output,
                tensor_args,
                non_tensor_args,
                unflatten_args,
                unbacked_bindings,
            ) = cls.process_kernel(kernel, inp)
        assert not unbacked_bindings, f"{kernel} {unbacked_bindings}"
        packed = cls(
            NoneLayout(inp.get_device()),
            kernel,
            tensor_args,
            non_tensor_args,
            unflatten_args,
        )

        mark_node_as_mutating(packed, inp)

    def get_read_writes(self):
        read_writes = super().get_read_writes()
        # See [Out-of-Place Collective Safety].
        volatile_reads = self.get_volatile_reads()
        for vr in volatile_reads:
            read_writes.reads.add(dependencies.StarDep(vr.get_name()))
        return read_writes


# NB: recursive structure here reflects val_to_arg_str, avoid
# calling free_unbacked_symbols on "exotic" types that don't get pexpr
# treatment
def maybe_free_unbacked_symbols(s):
    if isinstance(s, (SymTypes, sympy.Expr)):
        # This branch should be impossible in return position
        return free_unbacked_symbols(s)
    elif isinstance(s, (tuple, list)):
        r = set()
        for t in s:
            r |= maybe_free_unbacked_symbols(t)
        return r
    elif isinstance(s, torch.Tensor):
        # This branch is impossible in constant-args position
        return free_unbacked_symbols(s)
    else:
        return set()
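

# Rough illustration of the recursion above (assuming `u0` is an unbacked
# SymInt symbol appearing in one of the arguments):
#
#   maybe_free_unbacked_symbols([3, (u0 + 1, "flag")]) -> {u0}
#
# Plain ints, strings and other "exotic" constant args fall through to the
# final branch and contribute an empty set.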