# mypy: allow-untyped-defs
import builtins
import contextlib
import functools
import inspect
import itertools
import json
import logging
import math
import operator
import os
import sys
import textwrap
import time
from collections import namedtuple
from concurrent.futures import as_completed, ThreadPoolExecutor
from io import StringIO
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
from unittest.mock import patch

import sympy
from filelock import FileLock

import torch
import torch._inductor.async_compile  # noqa: F401 required to warm up AsyncCompile pools
from torch._dynamo.testing import rand_strided
from torch._dynamo.utils import counters, identity, preserve_rng_state

from . import config, ir
from .autotune_process import TensorMeta, TritonBenchmarkRequest
from .codecache import code_hash, PersistentCache, PyCodeCache
from .codegen.common import IndentedBuffer, KernelTemplate
from .codegen.triton import (
    gen_common_triton_imports,
    texpr,
    TritonKernel,
    TritonPrinter,
    TritonScheduling,
)
from .codegen.triton_utils import config_of, signature_to_meta
from .exc import CUDACompileError
from .ir import ChoiceCaller, PrimitiveInfoType
from .runtime.benchmarking import benchmarker
from .runtime.hints import DeviceProperties
from .utils import (
    FakeIndentedBuffer,
    get_dtype_size,
    Placeholder,
    restore_stdout_stderr,
    sympy_dot,
    sympy_index_symbol,
    sympy_product,
    unique,
)
from .virtualized import V


log = logging.getLogger(__name__)

# correctness checks struggle with fp16/tf32
VERIFY: Dict[str, Any] = {}
PRINT_AUTOTUNE = True
DEBUG = False


class KernelNamespace:
    pass


# these objects are imported from the generated wrapper code
extern_kernels = KernelNamespace()


class PartialRender:
    """
    Some parts of a template need to be generated at the end, but
    inserted into the template at the start.  This allows doing a bunch
    of replacements after the initial render.
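
    Illustrative sketch (hypothetical placeholder string and hook; the real hook
    keys such as "<DEF_KERNEL>" are registered by TritonTemplateKernel):

        pr = PartialRender(
            code="... <DEF_KERNEL> ...",
            replacement_hooks={"<DEF_KERNEL>": lambda: "def kernel(arg_A, arg_B, out_ptr0):"},
        )
        code = pr.finalize_all()  # runs every hook and substitutes its placeholder

    `finalize_hook` can instead be used to run a single named hook; each hook may
    only be finalized once.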
    """

    def __init__(self, code, replacement_hooks) -> None:
        super().__init__()
        self.code = code
        self.replacement_hooks = replacement_hooks

    def finalize_hook(self, hook_key: str, strict=True) -> None:
        if hook_key not in self.replacement_hooks:
            if strict:
                raise RuntimeError(
                    f"{hook_key} not registered in self.replacement_hooks"
                )
            else:
                return
        assert (
            self.replacement_hooks[hook_key] is not None
        ), "hook_key can only be called once"
        self.code = self.code.replace(hook_key, self.replacement_hooks[hook_key]())
        self.replacement_hooks[hook_key] = None

    def finalize_all(self) -> str:
        for key, fn in self.replacement_hooks.items():
            self.code = self.code.replace(key, fn())
        return self.code


# This is used to store info needed for lowering each subgraph in triton
# templates
SubgraphInfo = namedtuple(
    "SubgraphInfo",
    [
        "body",
        "template_mask",
        "template_out",
    ],
)


class TritonTemplateKernel(TritonKernel):
    def __init__(
        self,
        kernel_name,
        input_nodes,
        output_node,
        defines,
        num_stages,
        num_warps,
        grid_fn,
        meta,
        call_sizes,
        use_jit=False,
        prefix_args=0,
        suffix_args=0,
        epilogue_fn=identity,
        subgraphs: Optional[List[ir.ComputedBuffer]] = None,
        *,
        index_dtype,
    ) -> None:
        super().__init__(
            sympy_product(output_node.get_size()),
            sympy.Integer(1),
            index_dtype=index_dtype,
        )
        self.input_nodes = input_nodes
        self.output_node = output_node
        self.named_input_nodes = {}  # type: ignore[var-annotated]
        self.defines = defines
        self.kernel_name = kernel_name
        self.use_jit = use_jit
        self.num_stages = num_stages
        self.num_warps = num_warps
        self.grid_fn = grid_fn
        self.meta = meta
        self.call_sizes = call_sizes
        # for templates with fixed epilogues
        self.prefix_args = prefix_args
        self.suffix_args = suffix_args
        self.epilogue_fn = epilogue_fn
        self.render_hooks = {}  # type: ignore[var-annotated]
        self.triton_meta: Optional[Dict[str, object]] = None
        # For Templated Attention this can be a list of ir.Subgraph
        self.subgraphs: Optional[List[ir.ComputedBuffer]] = subgraphs

        # The following attributes (body, template_mask, template_out) are all
        # used for triton kernel codegen.
        # They are swapped onto the TritonTemplateKernel object by
        # `set_subgraph_body`
        self.subgraph_bodies: Dict[str, SubgraphInfo] = {}

        self.body: IndentedBuffer = FakeIndentedBuffer()
        self.template_mask: Optional[str] = None
        self.template_out: Optional[str] = None

    @contextlib.contextmanager
    def set_subgraph_body(self, body_name: str):
        old_body, old_mask, old_out = self.body, self.template_mask, self.template_out
        assert body_name in self.subgraph_bodies, body_name
        self.body, self.template_mask, self.template_out = self.subgraph_bodies[
            body_name
        ]
        yield
        self.subgraph_bodies[body_name] = SubgraphInfo(
            self.body, self.template_mask, self.template_out
        )
        self.body, self.template_mask, self.template_out = old_body, old_mask, old_out

    @contextlib.contextmanager
    def create_subgraph_body(self, body_name: str):
        assert body_name not in self.subgraph_bodies
        self.subgraph_bodies[body_name] = SubgraphInfo(IndentedBuffer(), None, None)
        with self.set_subgraph_body(body_name):
            yield

    def need_numel_args(self):
        return False

    def estimate_kernel_num_bytes(self):
        """
        Estimate the total number of bytes this kernel takes.
        For in/out nodes, sizes are counted twice: once for reading and
        once for writing.
        """
        ninplace_args = len(unique(self.args.inplace_buffers.values()))
        num_bytes = []
        for i, inp in enumerate(itertools.chain(self.input_nodes, (self.output_node,))):
            size = V.graph.sizevars.size_hints(inp.get_size())
            numel = functools.reduce(operator.mul, size, 1)
            dtype_size = get_dtype_size(inp.get_dtype())
            num_bytes.append(numel * dtype_size * (1 + int(i < ninplace_args)))
        return sum(num_bytes)

    def jit_lines(self):
        if self.use_jit:
            return "@triton.jit"

        argdefs, _, signature, _ = self.args.python_argdefs()
        triton_meta = {
            "signature": signature_to_meta(signature, size_dtype=self.index_dtype),
            "device": DeviceProperties.create(self.output_node.get_device()),
            "constants": {},
        }
        triton_meta["configs"] = [config_of(signature)]
        for arg_num in triton_meta["configs"][0].equal_to_1:  # type: ignore[index]
            triton_meta["constants"][arg_num] = 1  # type: ignore[index]
        matrix_instr_nonkdim = self.meta.get("matrix_instr_nonkdim", 0)
        if matrix_instr_nonkdim != 0:
            triton_meta["matrix_instr_nonkdim"] = matrix_instr_nonkdim

        self.triton_meta = triton_meta

        inductor_meta = {
            "kernel_name": str(Placeholder.DESCRIPTIVE_NAME),
            **TritonKernel.inductor_meta_common(),
        }
        if config.profile_bandwidth or config.benchmark_kernel:
            num_gb = self.estimate_kernel_num_bytes() / 1e9
            inductor_meta["kernel_num_gb"] = num_gb
        return f"""
            @triton_heuristics.template(
                num_stages={self.num_stages},
                num_warps={self.num_warps},
                triton_meta={triton_meta!r},
                inductor_meta={inductor_meta!r},
            )
            @triton.jit
        """

    def gen_argdefs(self):
        def hook():
            # python_argdefs() cannot be run until after the rest of the template lazily adds more args
            arg_defs, *_ = self.args.python_argdefs()
            return f"{', '.join(arg_defs)}"

        self.render_hooks["<ARGDEFS>"] = hook
        return "<ARGDEFS>"

    def gen_defines(self):
        return self.defines

    def def_kernel(self, *argnames):
        """
        Hook called from template code to generate function def and
        needed args.
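
        For example, a template typically begins with something like the following
        (illustrative; the argument names are whatever the template chooses):

            {{def_kernel("A", "B")}}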
        """
        assert all(isinstance(x, str) for x in argnames)
        renames = IndentedBuffer(initial_indent=1)

        named_args = self.input_nodes[
            self.prefix_args : len(self.input_nodes) - self.suffix_args
        ]

        assert len(argnames) == len(named_args), (
            len(argnames),
            len(named_args),
            self.prefix_args,
            len(self.input_nodes),
        )

        for input_node in self.input_nodes[: self.prefix_args]:
            # get args in correct order
            self.args.input(input_node.get_name())

        for name, input_node in zip(argnames, named_args):
            arg_name = f"arg_{name}"
            self.named_input_nodes[name] = input_node
            self.args.input_buffers[input_node.get_name()] = arg_name

        # The args may be duplicated, so renaming must be after args are de-duplicated.
        for name in argnames:
            input_node = self.named_input_nodes[name]
            arg_name = self.args.input_buffers[input_node.get_name()]
            if input_node.get_layout().offset == 0:
                renames.writeline(f"{name} = {arg_name}")
            else:
                offset = texpr(self.rename_indexing(input_node.get_layout().offset))
                renames.writeline(f"{name} = {arg_name} + {offset}")

        for input_node in self.input_nodes[len(self.input_nodes) - self.suffix_args :]:
            # get args in correct order
            self.args.input(input_node.get_name())

        def hook():
            # python_argdefs() cannot be run until after the rest of the template lazily adds more args
            arg_defs, *_ = self.args.python_argdefs()
            code = IndentedBuffer()
            code.splice(gen_common_triton_imports())
            code.splice(self.jit_lines())
            code.writeline(f"def {self.kernel_name}({', '.join(arg_defs)}):")
            with code.indent():
                code.splice(self.defines)
                code.splice(renames.getvalue())
            return code.getvalue()

        assert "<DEF_KERNEL>" not in self.render_hooks
        self.render_hooks["<DEF_KERNEL>"] = hook
        return "<DEF_KERNEL>"

    def size(self, name: str, index: int):
        """
        Hook called from template code to get the size of an arg.
        Will add needed args to pass it in if it is dynamic.
        """
        assert isinstance(index, int)
        if name is None:
            val = self.output_node.get_size()[index]
        else:
            assert isinstance(name, str)
            val = self.named_input_nodes[name].get_size()[index]
        return texpr(self.rename_indexing(val))

    def stride(self, name, index=None):
        """
        Hook called from template code to get the stride of an arg.
        Will add needed args to pass it in if it is dynamic.
        """
        if name is None:
            val = self.output_node.get_stride()
        else:
            assert isinstance(name, str)
            val = self.named_input_nodes[name].get_stride()

        if isinstance(index, int):
            return texpr(self.rename_indexing(val[index]))
        else:
            return ", ".join([texpr(self.rename_indexing(i)) for i in val])

    def modification(
        self, subgraph_number: int, output_name: str, **fixed_inputs
    ) -> str:
        """This creates a modification function for a subgraph.
        To use this inside a template, the first argument should specify which subgraph to codegen for.

        Args:
            subgraph_number (int): The index of the subgraph in self.subgraphs
            output_name (str): The name of the variable the generated code assigns the result to
        """
        num = 0
        while f"mod_{subgraph_number}_{num}" in self.subgraph_bodies:
            num += 1
        with self.create_subgraph_body(f"mod_{subgraph_number}_{num}"):
            assert isinstance(subgraph_number, int)
            assert isinstance(self.subgraphs, list)
            assert (
                self.body.getvalue() == ""
            ), "Body should be clear before adding a modification"
            assert subgraph_number < len(
                self.subgraphs
            ), f"Invalid subgraph number provided to create_modification, {subgraph_number} must be < {len(self.subgraphs)}"

            subgraph = self.subgraphs[subgraph_number]

            def add_input(name):
                return self.args.input(name)

            name = f"PlaceholderSubstitution_{subgraph_number}"

            class PlaceholderSubstitution(V.WrapperHandler):  # type: ignore[name-defined]
                self.name = name

                def load(self, name: str, index: sympy.Expr):
                    if name not in fixed_inputs:
                        # If it's not a fixed input, it's a load from a captured
                        # tensor
                        var = add_input(name)
                        return f"tl.load({var} + {index})"

                    return f"({fixed_inputs[name]})"

                def indirect_indexing(self, index_var, size, check, wrap_neg=True):
                    return sympy_index_symbol(str(index_var))

            with V.set_ops_handler(PlaceholderSubstitution(V.ops)):
                assert isinstance(
                    subgraph, ir.ComputedBuffer
                ), f"Expected the subgraph to be a ComputedBuffer, got {type(subgraph)}"
                if isinstance(subgraph.data, ir.InputBuffer):
                    out = subgraph.data.make_loader()(())
                else:
                    out = subgraph.data.inner_fn(())

            self.codegen_body()
            self.body.writeline(f"{output_name} = {out.value}")

            body_val = self.body.getvalue()
            self.cse.invalidate(set())  # type: ignore[arg-type]
            return body_val

    def store_output(
        self,
        indices: Union[List[Any], Tuple[Any]],
        val: str,
        mask: Optional[str] = None,
        indent_width: int = 4,
    ):
        """Stores the final output and appends any epilogue fusions if the buffer hasn't been optimized away.

        Args:
            indices (Union[List, Tuple]): The index for each dimension of the output. The dot product of
                these indices and output strides must match `val`.
            val (str): The value to store.
            mask (Optional[str]): An optional mask to use for the store operation. If provided, this mask
                will be applied to the store.
            indent_width (int): The number of spaces to use for indentation. This is used when the call to
                store_output is indented in the kernel definition.
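
        For example, a matmul-style template typically ends with a call like
        (illustrative; the index, value, and mask names depend on the template):

            {{store_output(("idx_m", "idx_n"), "acc", "mask")}}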
        """
        with self.create_subgraph_body("<STORE_OUTPUT>"):
            assert isinstance(indices, (list, tuple))
            assert isinstance(val, str)
            assert isinstance(mask, (str, type(None)))
            assert self.template_mask is None
            indices = list(map(TritonPrinter.paren, indices))
            index_symbols = [sympy.Symbol(x, integer=True) for x in indices]
            lengths = [
                V.graph.sizevars.simplify(s) for s in self.output_node.get_size()
            ]
            assert len(indices) == len(lengths)

            # glue to make generated code use same indexing from template
            for name, range_tree_entry in zip(
                indices, self.range_trees[0].construct_entries(lengths)
            ):
                range_tree_entry.set_name(name)
            contiguous_index = sympy_dot(
                ir.FlexibleLayout.contiguous_strides(lengths), index_symbols
            )
            contiguous_index = self.rename_indexing(contiguous_index)
            self.body.writeline("xindex = " + texpr(contiguous_index))
            self.range_trees[0].lookup(
                sympy.Integer(1), sympy_product(lengths)
            ).set_name("xindex")
            self.template_mask = mask
            self.template_out = val
            self.template_indices = indices
            output_index = self.output_node.get_layout().make_indexer()(index_symbols)
            output_index = self.rename_indexing(output_index)
            if output_index == contiguous_index:
                output_index = sympy.Symbol("xindex", integer=True)

            epilogue_args = [val]
            for input_node in itertools.chain(
                self.input_nodes[: self.prefix_args],
                self.input_nodes[len(self.input_nodes) - self.suffix_args :],
            ):
                input_node.freeze_layout()
                epilogue_args.append(input_node.make_loader()(index_symbols))

            V.ops.store(
                self.output_node.get_name(),
                output_index,
                self.epilogue_fn(*epilogue_args),
            )
            self.codegen_body()

        def hook():
            # more stuff might have been added since the codegen_body above
            self.codegen_body()

            return textwrap.indent(self.body.getvalue(), " " * indent_width).strip()

        assert "<STORE_OUTPUT>" not in self.render_hooks
        self.render_hooks["<STORE_OUTPUT>"] = hook
        return "<STORE_OUTPUT>"

    def render(self, template, kwargs):
        return PartialRender(
            template.render(**self.template_env(), **kwargs),
            self.render_hooks,
        )

    def make_load(self, name, indices, mask):
        """
        Optional helper called from template code to generate the code
        needed to load from a tensor.
        """
        assert isinstance(indices, (list, tuple))
        assert isinstance(name, str)
        assert isinstance(mask, str)
        stride = self.named_input_nodes[name].get_stride()
        indices = list(map(TritonPrinter.paren, indices))
        assert len(indices) == len(stride)
        index = " + ".join(
            f"{texpr(self.rename_indexing(s))} * {i}" for s, i in zip(stride, indices)
        )
        return f"tl.load({name} + ({index}), {mask}, other=0.0)"

    def template_env(self):
        """
        Generate the namespace visible in the template.
        """
        return {
            fn.__name__: fn
            for fn in [
                self.def_kernel,
                self.size,
                self.stride,
                self.store_output,
                self.make_load,
                self.modification,
                self.gen_argdefs,
                self.gen_defines,
            ]
        }

    def indexing(
        self,
        index: sympy.Expr,
        *,
        dense_indexing=False,
        copy_shape=None,
        override_mask=None,
        block_ptr=False,
    ):
        """
        Override the default indexing to use our custom mask and force
        dense indexing.
        """
        return super().indexing(
            index,
            dense_indexing=False,
            # We pass template_out as the shape to broadcast the indexing to as
            # the mask might be broadcast to the output shape
            copy_shape=self.template_out,
            override_mask=self.template_mask,
            block_ptr=block_ptr,
        )

    def codegen_range_tree(self):
        pass  # ignore default codegen

    def call_kernel(self, name: str, node: Optional[ir.IRNode] = None):
        wrapper = V.graph.wrapper_code
        _, call_args, _, arg_types = self.args.python_argdefs()
        if V.graph.cpp_wrapper:
            # In the cpp_wrapper case, we have to compute CUDA launch grid at runtime
            # if any dynamic dimension is involved. We rely on the Python version
            # of the grid function to generate those grid configs, which may contain
            # symbolic values. The wrapper will use cexpr to print out C++ code
            # appropriately for the grid configs.
            grid = self.call_sizes + [self.meta]
            wrapper.generate_kernel_call(
                name,
                call_args,
                grid=self.grid_fn(*grid),
                arg_types=arg_types,
                triton_meta=self.triton_meta,
            )
        else:
            wrapper.add_import_once(f"import {self.grid_fn.__module__}")
            meta = wrapper.add_meta_once(self.meta)
            grid = self.call_sizes + [meta]
            wrapper.generate_kernel_call(
                name,
                call_args,
                grid=grid,
                grid_fn=f"{self.grid_fn.__module__}.{self.grid_fn.__name__}",
                arg_types=arg_types,
                triton_meta=self.triton_meta,
            )


@functools.lru_cache(None)
def _jinja2_env():
    try:
        import jinja2

        return jinja2.Environment(
            undefined=jinja2.StrictUndefined,
        )
    except ImportError:
        return None


class TritonTemplate(KernelTemplate):
    index_counter = itertools.count()
    all_templates: Dict[str, "TritonTemplate"] = {}

    def __init__(self, name: str, grid: Any, source: str, debug=False) -> None:
        super().__init__(name)
        self.grid = grid
        self.template = self._template_from_string(source)
        assert name not in self.all_templates, "duplicate template name"
        self.all_templates[name] = self
        self.debug = debug

    def generate(  # type: ignore[override]
        self,
        input_nodes,
        layout,
        num_stages,
        num_warps,
        prefix_args=0,
        suffix_args=0,
        epilogue_fn=identity,
        subgraphs=None,
        mutated_inputs=None,
        call_sizes=None,
        **kwargs,
    ):
        """This function generates a TritonTemplateCaller.

        Args:
            input_nodes: List of input nodes
            layout: Output layout
            num_stages: Number of stages for triton launch
            num_warps: Number of warps for triton launch
            prefix_args: Number of leading input nodes that are excluded from the template's
                named args (they are still passed to the kernel and used by the fixed epilogue)
            suffix_args: Number of trailing input nodes that are excluded from the template's
                named args (they are still passed to the kernel and used by the fixed epilogue)
            epilogue_fn: Optional epilogue function to be called on the output
            subgraphs: Optional subgraphs to be passed as arguments; these will be inlined
                into the triton template string
            mutated_inputs: Optional list of input nodes that are mutated by the kernel. This is
                helpful if you need to return multiple outputs: you can pass them as inputs and
                mark them as being mutated by the kernel.
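
        Any remaining ``**kwargs`` become ``tl.constexpr`` defines in the generated
        kernel; for example (illustrative), passing ``BLOCK_M=64`` emits
        ``BLOCK_M : tl.constexpr = 64`` at the top of the kernel body.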
        """
        assert self.template, "requires jinja2"
        defines = StringIO()
        for name, val in kwargs.items():
            defines.write(f"{name} : tl.constexpr = {val}\n")
        defines = defines.getvalue()

        fake_out = ir.Buffer("buf_out", layout)
        kernel_name = f"triton_{self.name}"

        numel = sympy_product(layout.size)
        buffers = itertools.chain(input_nodes, (fake_out,))
        if not TritonScheduling.can_use_32bit_indexing(numel, buffers):
            raise NotImplementedError(
                "64-bit indexing is not yet implemented for triton templates"
            )

        if call_sizes is None:
            call_sizes = layout.size

        kernel_options = dict(
            input_nodes=input_nodes,
            defines=defines,
            num_stages=num_stages,
            num_warps=num_warps,
            grid_fn=self.grid,
            meta=kwargs,
            call_sizes=call_sizes,
            prefix_args=prefix_args,
            suffix_args=suffix_args,
            epilogue_fn=epilogue_fn,
            index_dtype="tl.int32",
            subgraphs=subgraphs,
        )

        with patch.object(
            V.graph, "get_dtype", self._fake_get_dtype(fake_out)
        ), TritonTemplateKernel(
            kernel_name=kernel_name,
            output_node=fake_out,
            use_jit=False,
            **kernel_options,
        ) as kernel:
            try:
                template = kernel.render(self.template, kwargs)
                with kernel.set_subgraph_body("<STORE_OUTPUT>"):
                    code = template.finalize_all()
            except ZeroDivisionError:
                # TODO(nmacchioni): fix sympy division by zero
                return None
            if self.debug:
                print("Generated Code:\n", code)
            extra = (
                "-".join(
                    [
                        *[
                            f"{kwarg}={repr(kwargs[kwarg])}"
                            for kwarg in sorted(kwargs.keys())
                        ],
                        f"num_stages={num_stages}",
                        f"num_warps={num_warps}",
                    ]
                )
                + "-"
            )
            mod = PyCodeCache.load(code, extra)

        input_call_args = tuple(kernel.args.input_buffers.keys())
        output_call_args = tuple(kernel.args.output_buffers.keys())

        # We expect the input_buffer order to be [*input_nodes, *captured_buffers]
        expected_input_args = tuple(unique(x.get_name() for x in input_nodes))
        expected_output_args = (fake_out.get_name(),)
        assert input_call_args[: len(expected_input_args)] == expected_input_args, (
            input_call_args,
            expected_input_args,
        )
        assert output_call_args == expected_output_args, (
            output_call_args,
            expected_output_args,
        )

        full_input_nodes = tuple([V.graph.get_buffer(k) for k in input_call_args])
        extra_args = V.graph.sizevars.size_hints(
            map(sympy.expand, tuple(kernel.args.sizevars.keys())),
            fallback=config.unbacked_symint_fallback,
        )

        kernel_hash_name = f"triton_{self.name}_{next(self.index_counter)}"

        def make_kernel_render(out_node):
            kernel = TritonTemplateKernel(
                kernel_name=str(Placeholder.KERNEL_NAME),
                output_node=out_node,
                use_jit=False,
                **kernel_options,
            )
            render = functools.partial(
                kernel.render,
                self.template,
                kwargs,
            )
            return kernel, render

        # create the BenchmarkRequest
        assert mod.__file__ is not None
        grid = self.grid(
            *V.graph.sizevars.size_hints(
                call_sizes,
                fallback=config.unbacked_symint_fallback,
            ),
            kwargs,
        )
        bmreq = TritonBenchmarkRequest(
            module_path=mod.__file__,
            module_cache_key=mod.key,
            kernel_name=kernel_name,
            grid=grid,
            extra_args=extra_args,
            num_stages=num_stages,
            num_warps=num_warps,
            matrix_instr_nonkdim=kwargs.get("matrix_instr_nonkdim", 0),
            input_tensor_meta=TensorMeta.from_irnodes(full_input_nodes),  # type: ignore[arg-type]
            output_tensor_meta=TensorMeta.from_irnodes(layout),
        )

        return TritonTemplateCaller(
            kernel_hash_name,
            full_input_nodes,
            layout,
            make_kernel_render,
            extra.strip("-").replace("-", ", "),
            bmreq,
            log_info={
                "tile_shape": str(
                    (
                        kwargs.get("BLOCK_M", -1),
                        kwargs.get("BLOCK_K", -1),
                        kwargs.get("BLOCK_N", -1),
                    )
                ),
                "num_stages": num_stages,
                "num_warps": num_warps,
                "allow_tf32": str(kwargs.get("ALLOW_TF32", None)),
                "acc_type": str(kwargs.get("ACC_TYPE", None)),
            },
            mutated_inputs=mutated_inputs,
        )


class ExternKernelChoice:
    def __init__(
        self,
        kernel,
        cpp_kernel=None,
        *,
        name=None,
        has_out_variant=True,
        op_overload=None,
        use_fallback_kernel=False,
        kernel_creator=None,
    ) -> None:
        super().__init__()
        name = name or kernel.__name__
        assert callable(kernel)
        assert not hasattr(extern_kernels, name), f"duplicate extern kernel: {name}"
        self.name = name
        self.cpp_kernel_name = cpp_kernel
        self.has_out_variant = has_out_variant
        setattr(extern_kernels, name, kernel)
        self.op_overload = op_overload
        self.use_fallback_kernel = use_fallback_kernel
        self.kernel_creator = kernel_creator

    def to_callable(self):
        return getattr(extern_kernels, self.name)

    def call_name(self):
        return f"extern_kernels.{self.name}"

    @functools.lru_cache(None)  # noqa: B019
    def hash_key(self):
        fn = self.to_callable()
        parts = [
            self.name,
            getattr(fn, "__name__", ""),
            getattr(fn, "__module__", ""),
        ]
        try:
            parts.append(inspect.getsource(fn))
        except Exception:
            pass
        return code_hash("-".join(parts))

    def bind(
        self,
        input_nodes,
        layout,
        ordered_kwargs_for_cpp_kernel=(),
        **kwargs,
    ):
        self.ordered_kwargs_for_cpp_kernel = ordered_kwargs_for_cpp_kernel
        return ExternKernelCaller(
            self, input_nodes, layout, kwargs, has_out_variant=self.has_out_variant
        )


class TritonTemplateCaller(ir.TritonTemplateCallerBase):
    def __init__(
        self,
        name,
        input_nodes,
        layout,
        make_kernel_render,
        debug_extra,
        bmreq,
        log_info: Optional[
            Dict[str, Union[PrimitiveInfoType, List[PrimitiveInfoType]]]
        ] = None,
        mutated_inputs=None,
    ) -> None:
        super().__init__(name, input_nodes, layout)
        self.make_kernel_render = make_kernel_render
        self.debug_extra = debug_extra
        self.bmreq: TritonBenchmarkRequest = bmreq
        if log_info is None:
            log_info = {}
        self.log_info: Dict[str, Any] = log_info
        self.log_info.update(
            {
                "backend": "Triton",
                "grid": str(self.bmreq.grid),
                "num_stages": self.bmreq.num_stages,
                "num_warps": self.bmreq.num_warps,
            }
        )
        self.mutated_inputs = mutated_inputs

    def benchmark(self, *args, out):
        assert self.bmreq is not None
        return self.bmreq.benchmark(*args, output_tensor=out)

    def precompile(self):
        assert self.bmreq is not None
        self.bmreq.precompile()

    def __str__(self) -> str:
        return f"TritonTemplateCaller({self.bmreq.module_path}, {self.debug_extra})"

    def call_name(self):
        return f"template_kernels.{self.name}"

    def hash_key(self):
        return "-".join(
            [
                self.name.rsplit("_", 1)[0],
                self.bmreq.module_cache_key,
            ]
        )

    def output_node(self):
        return ir.TensorBox.create(
            ir.TritonTemplateBuffer(
                layout=self.layout,
                inputs=self.input_nodes,
                make_kernel_render=self.make_kernel_render,
                debug_extra=self.debug_extra,
                mutated_inputs=self.mutated_inputs,
            )
        )

    def info_dict(self) -> Dict[str, Union[PrimitiveInfoType, List[PrimitiveInfoType]]]:
        """Information returned here is logged to the autotune log file when that is enabled."""
        return self.log_info

    def get_make_kernel_render(self):
        return self.make_kernel_render

    def autoheuristic_id(self):
        type_name = "triton"
        info = self.info_dict()
        # TODO(AlnisM): Does tile_shape always exist?
        tile = info["tile_shape"]
        tile_vals = eval(tile)  # type: ignore[arg-type]
        BLOCK_M = tile_vals[0]
        BLOCK_K = tile_vals[1]
        BLOCK_N = tile_vals[2]
        num_stages = info["num_stages"]
        num_warps = info["num_warps"]
        return f"type={type_name}_BLOCK-M={BLOCK_M}_BLOCK-K={BLOCK_K}_BLOCK-N={BLOCK_N}_numstages={num_stages}_numwarps={num_warps}"


class ExternKernelCaller(ChoiceCaller):
    def __init__(
        self,
        choice: ExternKernelChoice,
        input_nodes,
        layout,
        kwargs=None,
        *,
        has_out_variant=True,
    ) -> None:
        super().__init__(choice.name, input_nodes, layout)
        self.choice = choice
        self.kwargs = kwargs or {}
        self.has_out_variant = has_out_variant

    def __str__(self) -> str:
        return f"ExternKernelCaller({self.choice.call_name()})"

    def benchmark(self, *args, out):
        if out.numel() == 0:
            # no need to run the kernel or do benchmarking
            return 0.0
        if self.has_out_variant:
            return super().benchmark(*args, out=out)
        else:
            algo = self.to_callable()
            out_new = algo(*args)
            torch._C._dynamo.guards.assert_size_stride(
                out_new, tuple(out.size()), tuple(out.stride())
            )
            out.copy_(out_new)  # for correctness checking
            return benchmarker.benchmark(algo, args, {})

    def to_callable(self):
        fn = self.choice.to_callable()
        if self.kwargs:
            return functools.partial(fn, **self.kwargs)
        else:
            return fn

    def hash_key(self):
        return "-".join(
            [
                self.choice.name,
                *[
                    f"{kwarg}={repr(self.kwargs[kwarg])}"
                    for kwarg in sorted(self.kwargs.keys())
                ],
                self.choice.hash_key(),
            ]
        )

    def output_node(self):
        if config.abi_compatible and self.choice.use_fallback_kernel:
            assert (
                self.choice.op_overload is not None
            ), "Please provide an op_overload to use ir.FallbackKernel"
            inner = ir.FallbackKernel.create(
                self.choice.op_overload, *self.input_nodes, **self.kwargs
            )
        elif self.choice.kernel_creator is not None:
            inner = self.choice.kernel_creator(*self.input_nodes, **self.kwargs)
        else:
            cls = ir.ExternKernelOut if self.has_out_variant else ir.ExternKernelAlloc
            inner = cls(
                layout=self.layout,
                inputs=self.input_nodes,
                python_kernel_name=self.choice.call_name(),
                cpp_kernel_name=self.choice.cpp_kernel_name,
                ordered_kwargs_for_cpp_kernel=self.choice.ordered_kwargs_for_cpp_kernel,
                op_overload=self.choice.op_overload,
                kwargs=self.kwargs,
            )

        return ir.TensorBox.create(inner)

    def info_dict(self) -> Dict[str, Union[PrimitiveInfoType, List[PrimitiveInfoType]]]:
        """Information returned here is logged to the autotune log file when that is enabled."""
        return {
            "backend": "extern",
            "kernel_call_name": self.choice.call_name(),
        }

    def autoheuristic_id(self):
        return f"extern_{self.choice.name}"


@functools.lru_cache(None)
def get_mm_log_filename() -> Optional[str]:
    mm_file_name = os.environ.get("TORCHINDUCTOR_MM_LOGGING_FILE", None)
    if not mm_file_name:
        return None

    if "json" not in mm_file_name:
        mm_file_name = f"{mm_file_name}.json"

    return mm_file_name


def append_to_log(filename, data):
    lock_file = filename.replace(".json", ".lock")
    lock = FileLock(lock_file)
    with lock:
        try:
            with open(filename) as f:
                log_data = json.load(f)
        except (FileNotFoundError, json.JSONDecodeError):
            log_data = []

        log_data.append(data)

        with open(filename, "w") as f:
            json.dump(log_data, f, indent=4)


class DataProcessorChoiceCallerWrapper:
    def __init__(self, wrapped, preprocessor, postprocessor) -> None:
        self._wrapped = wrapped
        if preprocessor is not None:
            self._preprocessor = preprocessor
        else:
            self._preprocessor = lambda x, y: (x, y)
        if postprocessor is not None:
            self._postprocessor = postprocessor
        else:
            self._postprocessor = lambda x: x

    def __getattr__(self, name):
        return getattr(self._wrapped, name)

    def benchmark(self, *args, out) -> float:
        new_args, new_out = self._preprocessor(args, out)
        result = self._wrapped.benchmark(*new_args, out=new_out)
        new_out = self._postprocessor(new_out)
        if out is not new_out:
            out.copy_(new_out)
        return result

    def output_node(self) -> ir.TensorBox:
        result = self._wrapped.output_node()
        return self._postprocessor(result)

    def __repr__(self) -> str:
        return f"DataProcessorChoiceCallerWrapper({self._wrapped})"


class DataProcessorTemplateWrapper:
    """
    A wrapper class for a kernel template.

    This class together with `DataProcessorChoiceCallerWrapper` provides a convenient way to
    preprocess and postprocess data before and after using the wrapped template. A typical
    usage is to reorder or filter the input nodes in order to match the expected input of other
    kernel choices like an ATen kernel. A more complicated usage is to prepack the weights.
    See the example from :mod:`cpp_gemm_template` for more details.
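
    A rough sketch of the expected callables (assumed shapes, not an exact API):

        def preprocessor(input_nodes, layout):
            # e.g. reorder or filter nodes to match what the wrapped template expects
            return input_nodes, layout

        def postprocessor(output):
            # e.g. wrap or transform the output node produced by the wrapped template
            return output

        template = DataProcessorTemplateWrapper(
            SomeTritonTemplate,  # hypothetical wrapped template class
            preprocessor,
            postprocessor,
            input_nodes=input_nodes,
            layout=layout,
        )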
    """

    def __init__(
        self,
        wrapped_template_cls,
        preprocessor,
        postprocessor,
        **kwargs,
    ) -> None:
        if preprocessor is not None:
            self._preprocessor = preprocessor
        else:
            self._preprocessor = lambda x, y: (x, y)
        if postprocessor is not None:
            self._postprocessor = postprocessor
        else:
            self._postprocessor = lambda x: x
        assert "input_nodes" in kwargs
        assert "layout" in kwargs
        kwargs["input_nodes"], kwargs["layout"] = preprocessor(
            kwargs["input_nodes"], kwargs["layout"]
        )
        self._wrapped = wrapped_template_cls(**kwargs)

    def __getattr__(self, name):
        return getattr(self._wrapped, name)

    def maybe_append_choice(self, choices, **kwargs):
        return type(self._wrapped).maybe_append_choice(self, choices, **kwargs)

    def generate(self, **kwargs):
        choice_caller = self._wrapped.generate(**kwargs)
        return DataProcessorChoiceCallerWrapper(
            choice_caller, self._preprocessor, self._postprocessor
        )

    def __repr__(self) -> str:
        return f"DataProcessorTemplateWrapper({self._wrapped})"


class ErrorFromChoice(RuntimeError):
    def __init__(self, msg, choice: ChoiceCaller, inputs_str) -> None:
        msg += f"\nFrom choice {choice}\n{inputs_str}"
        super().__init__(msg)
        self.choice = choice


class NoValidChoicesError(RuntimeError):
    pass


@functools.lru_cache(None)
def get_env_num_workers() -> Optional[int]:
    if "TORCHINDUCTOR_COMPILE_THREADS" in os.environ:
        return int(os.environ["TORCHINDUCTOR_COMPILE_THREADS"])
    return None


def create_inputs_key(input_nodes) -> str:
    return repr([AlgorithmSelectorCache.key_of(x) for x in input_nodes])


def create_precompile_key(
    name: str, inputs_key: str, choices: List[ChoiceCaller]
) -> str:
    return ":".join(
        [
            name,
            inputs_key,
            torch.get_float32_matmul_precision(),
        ]
        + [choice.hash_key() for choice in choices]
    )


class AlgorithmSelectorCache(PersistentCache):
    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)

        # the autotuning will occur in the scheduler, so there is
        # no guarantee that the first lowering for a given key will also be the
        # first to benchmark it. Share a single precompilation function for all
        # lowerings of a particular key.
        self.precompile_cache: Dict[str, Callable[[], None]] = {}
        # list of callbacks that are called after benchmarking
        self.feedback_saver_fns: List[
            Callable[
                [Dict[ChoiceCaller, float], str, List[Any], List[ChoiceCaller]], None
            ]
        ] = []

    def __call__(
        self,
        name,
        choices: List[ChoiceCaller],
        input_nodes,
        layout,
        # optional dict mapping arg indices to the functions
        # generating a torch.Tensor for that input from the
        # corresponding ir.Buffer. If passed for a given
        # arg, the function will be called instead of
        # generating a random torch.Tensor for benchmarking.
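        # For example (hypothetical), {0: lambda buf: make_input_for(buf)} would make
        # the first argument be produced by `make_input_for` instead of a random tensor.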
        input_gen_fns: Optional[Dict[int, Callable[[ir.Buffer], torch.Tensor]]] = None,
        precompilation_timeout_seconds: int = 60 * 60,
        return_multi_template=False,
    ):
        from .codegen.cuda.cuda_kernel import CUDATemplateCaller

        # Templates selected with input_gen_fns require specific input data to avoid IMA
        # Passing custom input gen fns to benchmark_fusion NYI, so skip deferred template selection
        # TODO(jgong5): support multi-template on CPU
        if input_gen_fns is not None or layout.device.type == "cpu":
            return_multi_template = False

        # TODO - assert that we have no mutating kernels here

        # TODO(nmacchioni): remove once CI tests are fixed
        choices = [choice for choice in choices if choice is not None]

        if mm_file_name := get_mm_log_filename():
            M, K = input_nodes[-2].get_size()[:2]
            N = input_nodes[-1].get_size()[-1]
            append_to_log(mm_file_name, {"invoke": str((M, K, N))})

        if len(choices) == 0:
            backend_config = (
                "max_autotune_gemm_backends"
                if name != "convolution"
                else "max_autotune_conv_backends"
            )
            raise NoValidChoicesError(
                f"No choices to select, please consider adding ATEN into {backend_config} "
                "config (defined in torch/_inductor/config.py) to allow at least one choice. "
            )
        log.debug("Max autotune selects from %s choices.", str(len(choices)))

        if len(choices) == 1:
            if not isinstance(choices[0], CUDATemplateCaller):
                # CUDATemplateCaller still needs to go through autotuning process to retrieve workspace size.
                return choices[0].output_node()

        @functools.lru_cache(None)
        def make_benchmark_fn():
            return self.make_benchmark_fn(choices, input_nodes, layout, input_gen_fns)

        inputs_key = create_inputs_key(input_nodes)

        def precompile(choices) -> Callable[[], None]:
            def no_op(*args, **kwargs):
                return

            if (
                precompilation_timeout_seconds is None
                or precompilation_timeout_seconds <= 0
            ):
                return no_op

            env_workers = get_env_num_workers()
            num_workers = env_workers if env_workers is not None else (len(choices))

            if num_workers <= 0:
                return no_op

            # https://github.com/python/cpython/issues/106905
            if (
                sys.version_info.major == 3
                and sys.version_info.minor == 11
                and sys.version_info.micro <= 8
            ):
                return no_op

            # check local and global cache before precompiling
            timings = self.lookup(
                choices,
                name,
                inputs_key,
                benchmark=None,
            )

            if timings:
                return no_op

            precompile_key = create_precompile_key(name, inputs_key, choices)
            if precompile_func := self.precompile_cache.get(precompile_key):
                return precompile_func

            log.info(
                "Multithreaded precompilation for %d choices using %d worker threads",
                len(choices),
                num_workers,
            )

            # In rare circumstances, because python threads inherit global state,
            # thread pool executor can race and leave stdout/stderr in a state
            # different than the original values. We explicitly restore the state
            # here to avoid this issue.

            initial_stdout = sys.stdout
            initial_stderr = sys.stderr

            def precompile_with_captured_stdout(choice):
                with restore_stdout_stderr(initial_stdout, initial_stderr):
                    return choice.precompile()

            executor = ThreadPoolExecutor(max_workers=num_workers)

            futures = {}
            for c in choices:
                if hasattr(c, "precompile"):
                    future = executor.submit(precompile_with_captured_stdout, c)
                    futures[future] = c

            @functools.lru_cache(None)
            @restore_stdout_stderr(initial_stdout, initial_stderr)
            def wait_on_futures():
                counters["inductor"]["select_algorithm_precompile"] += 1
                for future in as_completed(
                    futures,
                    timeout=precompilation_timeout_seconds,
                ):
                    if e := future.exception():
                        log.error(
                            "Exception %s for benchmark choice %s", e, futures[future]
                        )

                executor.shutdown(wait=True)

            self.precompile_cache[precompile_key] = wait_on_futures

            return wait_on_futures

        def autotune(choices):
            return make_benchmark_fn()(choices)

        if config.autotune_in_subproc:
            from .autotune_process import tuning_pool

            # do the optional warmup
            tuning_pool.initialize()

        def do_autotuning(precompile_fn):
            precompile_start_ts = time.time()
            precompile_fn()
            precompile_elapse = time.time() - precompile_start_ts

            autotune_start_ts = time.time()
            timings = self.lookup(
                choices,
                name,
                inputs_key,
                autotune,
            )
            autotune_elapse = time.time() - autotune_start_ts

            if timings and all(
                not math.isfinite(timing) for timing in timings.values()
            ):
                raise NoValidChoicesError

            if make_benchmark_fn.cache_info().currsize:
                counters["inductor"]["select_algorithm_autotune"] += 1

            if (
                make_benchmark_fn.cache_info().currsize
                or log.getEffectiveLevel() == logging.DEBUG
                or config.trace.log_autotuning_results
            ):
                self.log_results(
                    name, input_nodes, timings, autotune_elapse, precompile_elapse
                )

            for feedback_fn in self.feedback_saver_fns:
                feedback_fn(timings, name, input_nodes, choices)

            return timings

        precompile_fn = precompile(choices)

        if return_multi_template and (config.max_autotune or config.max_autotune_gemm):

            def get_timings():
                timings = do_autotuning(precompile_fn)
                min_extern_choice = float("inf")
                for choice, timing in timings.items():
                    if isinstance(choice, ExternKernelCaller):
                        min_extern_choice = min(min_extern_choice, timing)

                timings = {
                    choice: time
                    for choice, time in timings.items()
                    if (
                        time <= min_extern_choice
                        or not isinstance(choice, ExternKernelCaller)
                    )
                }

                return timings

            return torch._inductor.ir.TensorBox.create(
                torch._inductor.ir.MultiTemplateBuffer(
                    layout,
                    input_nodes,
                    get_timings,
                )
            )

        # TODO - don't want to precompile if we have a cache hit
        timings = do_autotuning(precompile_fn)
        if timings == {} or choices[0] not in timings:
            return choices[0].output_node()

        selected_key = builtins.min(timings, key=timings.__getitem__)
        selected_time = timings[selected_key]
        selected_choice = selected_key.output_node()
        log.debug("selected choice: %s", str(selected_choice))
        return selected_choice

    @classmethod
    def make_benchmark_fn(
        cls,
        choices,
        input_nodes,
        layout,
        input_gen_fns=None,
    ):
        if input_gen_fns is None:
            input_gen_fns = {}

        def get_inputs():
            # de-duplicate args
            unique_example_inputs = {
                x.get_name(): input_gen_fns.get(i, cls.benchmark_example_value)(x)
                for i, x in enumerate(input_nodes)
            }
            example_inputs = list(unique_example_inputs.values())
            example_inputs_extern = [
                unique_example_inputs[input_node.get_name()]
                if unique_example_inputs[input_node.get_name()].is_mkldnn
                else torch.as_strided(
                    unique_example_inputs[input_node.get_name()],
                    V.graph.sizevars.size_hints(
                        input_node.get_size(),
                        fallback=config.unbacked_symint_fallback,
                    ),
                    V.graph.sizevars.size_hints(
                        input_node.get_stride(),
                        fallback=config.unbacked_symint_fallback,
                    ),
                    V.graph.sizevars.size_hint(
                        input_node.get_layout().offset,
                        fallback=config.unbacked_symint_fallback,
                    ),
                )
                for input_node in input_nodes
            ]

            out = cls.benchmark_example_value(layout)
            out_extern = torch.as_strided(
                out, out.size(), out.stride(), V.graph.sizevars.size_hint(layout.offset)
            )
            expected = None
            if VERIFY:
                choices[0].benchmark(*example_inputs_extern, out=out_extern)
                expected = out_extern.clone()

            return example_inputs, example_inputs_extern, out, out_extern, expected

        if DEBUG:
            print(f"{len(choices)} tuning requests:")

        def debug_str(example_inputs, out):
            def tensor_repr(x):
                return (
                    f"torch.empty_strided({tuple(x.size())!r}, {tuple(x.stride())!r}, "
                    f"dtype={x.dtype!r}, device={x.device.type!r})"
                )

            lines = [
                "inputs = [",
            ]
            for x in example_inputs:
                lines.append(f" {tensor_repr(x)},")
            lines += ["]", f"out = {tensor_repr(out)}", ""]
            return "\n".join(lines)

        def benchmark_choice_in_current_process(
            choice, example_inputs, example_inputs_extern, out, out_extern, expected
        ):
            out.zero_()
            if isinstance(choice, ExternKernelCaller):
                # aten kernels want the offset baked in for sliced tensors
                result = choice.benchmark(*example_inputs_extern, out=out_extern)
            else:
                # triton templates want the base pointer for sliced tensors
                result = choice.benchmark(*example_inputs, out=out)
            if VERIFY and expected is not None:
                torch.testing.assert_close(out_extern, expected, **VERIFY)
            if torch.cuda.is_available():
                torch.cuda.synchronize()  # shake out any CUDA errors
            return result

        def benchmark_in_current_process(choices):
            inputs = get_inputs()
            example_inputs, _, out, _, _ = inputs
            timings = {}
            for choice in choices:
                try:
                    timing = benchmark_choice_in_current_process(choice, *inputs)
                except CUDACompileError as e:
                    log.error(
                        "CUDA compilation error during autotuning: \n%s. \nIgnoring this choice.",
                        str(e),
                    )
                    timing = float("inf")
                except NotImplementedError as e:
                    log.warning("Not yet implemented: %s", e)
                    timing = float("inf")
                except RuntimeError as e:
                    msg = str(e)
                    if "invalid argument" in msg:
                        msg += "\n\nThis may mean this GPU is too small for max_autotune mode.\n\n"
                    else:
                        if "illegal memory access" in msg:
                            msg += "\n\nEither error in template or triton bug.\n"
                    log.error(
                        "Runtime error during autotuning: \n%s. \nIgnoring this choice.",
                        msg,
                    )
                    timing = float("inf")
                except AssertionError as e:
                    raise AssertionError(  # noqa: B904
                        f"Incorrect result from choice {choice}\n\n{e}"
                    )
                except Exception as e:
                    try:
                        from triton.runtime.autotuner import OutOfResources

                        if isinstance(e, OutOfResources):
                            log.warning(e)
                            timing = float("inf")
                        else:
                            raise e
                    except ImportError:
                        raise e from None

                timings[choice] = timing

            return timings

        def benchmark_in_sub_process(choices):
            from . import autotune_process

            # only benchmark triton kernels in a subprocess for now.
            # ATen/Extern kernels are still benchmarked in the current process.
            extern = [c for c in choices if isinstance(c, ExternKernelCaller)]
            triton = [c for c in choices if not isinstance(c, ExternKernelCaller)]

            timings = benchmark_in_current_process(extern)
            timings.update(autotune_process.benchmark_in_sub_process(triton))
            return timings

        benchmark = (
            benchmark_in_sub_process
            if config.autotune_in_subproc
            else benchmark_in_current_process
        )

        return benchmark

    @staticmethod
    def log_results(
        name: str,
        input_nodes: List[ir.IRNode],
        timings: Dict[ChoiceCaller, float],
        elapse: float,
        precompile_elapse: float,
    ):
        V.debug.log_autotuning_results(
            name, input_nodes, timings, elapse, precompile_elapse
        )
        if not (config.max_autotune or config.max_autotune_gemm) or not PRINT_AUTOTUNE:
            return
        sizes = ", ".join(
            [
                "x".join(
                    map(
                        str,
                        V.graph.sizevars.size_hints(
                            n.get_size(), fallback=config.unbacked_symint_fallback
                        ),
                    )
                )
                for n in input_nodes
            ]
        )

        n = None if log.getEffectiveLevel() == logging.DEBUG else 10
        top_k = sorted(timings, key=timings.__getitem__)[:n]
        best = top_k[0]

        def get_choice_info(choice):
            if isinstance(choice, torch._inductor.select_algorithm.ExternKernelCaller):
                return {"type": "cublas", "time": timings[choice]}

            assert isinstance(
                choice, torch._inductor.select_algorithm.TritonTemplateCaller
            )

            info = choice.info_dict()
            tile = info["tile_shape"]

            tile_vals = eval(tile)  # type: ignore[arg-type]
            BLOCK_M = tile_vals[0]
            BLOCK_K = tile_vals[1]
            BLOCK_N = tile_vals[2]

            return {
                "type": "triton",
                "time": timings[choice],
                "BLOCK_M": BLOCK_M,
                "BLOCK_K": BLOCK_K,
                "BLOCK_N": BLOCK_N,
                "num_stages": info["num_stages"],
                "num_warps": info["num_warps"],
            }

        mm_filename = get_mm_log_filename()
        if mm_filename and "mm" in name:
            M, K = input_nodes[-2].get_size()[:2]
            N = input_nodes[-1].get_size()[-1]

            out_dict = {
                str((M, K, N)): [get_choice_info(choice) for choice in timings.keys()]
            }

            append_to_log(mm_filename, out_dict)

        best_time = timings[best]
        sys.stderr.write(f"AUTOTUNE {name}({sizes})\n")
        for choice in top_k:
            result = timings[choice]
            if result:
                kernel_info = (
                    choice.debug_extra if hasattr(choice, "debug_extra") else ""
                )
                sys.stderr.write(
                    f" {choice.name} {result:.4f} ms {best_time / result:.1%} {kernel_info}\n"
                )
            else:
                sys.stderr.write(
                    f" {choice.name} {result:.4f} ms <DIVIDED BY ZERO ERROR>\n"
                )

        autotune_type_str = (
            "SubProcess" if config.autotune_in_subproc else "SingleProcess"
        )
        sys.stderr.write(
            f"{autotune_type_str} AUTOTUNE benchmarking takes {elapse:.4f} seconds and {precompile_elapse:.4f}"
            " seconds precompiling\n"
        )

    @staticmethod
    def benchmark_example_value(node):
        """
        Convert an ir.Buffer into a concrete torch.Tensor we can use for
        benchmarking.
        """
        if isinstance(node, ir.Layout):
            node = ir.Buffer("fake", node)
        # triton templates want the base tensor.
        if isinstance(node, ir.BaseView):
            node = node.unwrap_view()
        return AlgorithmSelectorCache.generate_example_value(
            V.graph.sizevars.size_hints(
                node.get_size(),
                fallback=config.unbacked_symint_fallback,
            ),
            V.graph.sizevars.size_hints(
                node.get_stride(),
                fallback=config.unbacked_symint_fallback,
            ),
            node.get_device(),
            node.get_dtype(),
            node.layout.offset,
        )

    @staticmethod
    def generate_example_value(size, stride, device, dtype, extra_size):
        # preserve rng states to avoid the rand_strided call below changing
        # the rng state for the real model code.
        with preserve_rng_state():
            return rand_strided(
                size,
                stride,
                device=device,
                dtype=dtype,
                extra_size=extra_size,
            )

    @staticmethod
    def key_of(node):
        """
        Extract the pieces of an ir.Buffer that we should invalidate cached
        autotuning results on.
        """
        sizevars = V.graph.sizevars
        return (
            node.get_device().type,
            str(node.get_dtype()),
            *sizevars.size_hints(
                node.get_size(),
                fallback=config.unbacked_symint_fallback,
            ),
            *sizevars.size_hints(
                node.get_stride(),
                fallback=config.unbacked_symint_fallback,
            ),
            sizevars.size_hint(
                node.get_layout().offset,
                fallback=config.unbacked_symint_fallback,
            ),
        )

    def add_feedback_saver(
        self,
        fn: Callable[
            [Dict[ChoiceCaller, float], str, List[Any], List[ChoiceCaller]], None
        ],
    ):
        self.feedback_saver_fns.append(fn)


_ALGORITHM_SELECTOR_CACHE: Optional[AlgorithmSelectorCache] = None


def autotune_select_algorithm(*args, **kwargs):
    global _ALGORITHM_SELECTOR_CACHE
    if _ALGORITHM_SELECTOR_CACHE is None:
        _ALGORITHM_SELECTOR_CACHE = AlgorithmSelectorCache()

    if "return_multi_template" not in kwargs:
        kwargs[
            "return_multi_template"
        ] = torch._inductor.config.benchmark_epilogue_fusion

    return _ALGORITHM_SELECTOR_CACHE(*args, **kwargs)


def add_feedback_saver(
    fn: Callable[[Dict[ChoiceCaller, float], str, List[Any], List[ChoiceCaller]], None]
):
    global _ALGORITHM_SELECTOR_CACHE
    if _ALGORITHM_SELECTOR_CACHE is None:
        _ALGORITHM_SELECTOR_CACHE = AlgorithmSelectorCache()
    _ALGORITHM_SELECTOR_CACHE.add_feedback_saver(fn)


def realize_inputs(*args):
    if len(args) == 1:
        return ir.ExternKernel.require_stride1(ir.ExternKernel.realize_input(args[0]))
    return [realize_inputs(x) for x in args]


# ensure lowering is imported so that `extern_kernels.*` is populated
from . import lowering  # noqa: F401