# mypy: allow-untyped-defs
from __future__ import annotations

import contextlib
import ctypes
import dataclasses
import functools
import logging
import os
import queue
import time
import warnings
from concurrent.futures import ThreadPoolExecutor
from ctypes import byref, c_size_t, c_void_p, CDLL
from typing import (
    Any,
    Callable,
    Dict,
    Iterable,
    List,
    Optional,
    Sequence,
    TYPE_CHECKING,
    Union,
)

import torch
import torch._inductor.async_compile  # noqa: F401 required to warm up AsyncCompile pools
from torch import multiprocessing
from torch._dynamo.testing import rand_strided
from torch._inductor import ir
from torch._inductor.codecache import (
    CppCodeCache,
    CUDACodeCache,
    DLLWrapper,
    get_hash,
    PyCodeCache,
)


if TYPE_CHECKING:
    from multiprocessing.process import BaseProcess
    from multiprocessing.queues import Queue
    from types import ModuleType

    from torch._inductor.select_algorithm import TritonTemplateCaller

from . import config
from .runtime.benchmarking import benchmarker
from .virtualized import V


CUDA_VISIBLE_DEVICES = "CUDA_VISIBLE_DEVICES"
EXIT_HANDLER_REGISTERED = False

log = logging.getLogger(__name__)


# Used to synchronize between parent and child processes
class Ping:
    pass


class Pong:
    pass


class NonzeroWorkspaceNotSupportedError(Exception):
    pass


@contextlib.contextmanager
def set_cuda_visible_device(device: Optional[int]):
    """
    Context manager to set the CUDA_VISIBLE_DEVICES environment variable to the
    specified single device. If device is None, don't manipulate the environment.
    """
    if device is None:
        yield
        return

    current = os.environ.get(CUDA_VISIBLE_DEVICES)
    os.environ[CUDA_VISIBLE_DEVICES] = str(device)
    try:
        yield
    finally:
        if current is None:
            del os.environ[CUDA_VISIBLE_DEVICES]
        else:
            os.environ[CUDA_VISIBLE_DEVICES] = current


@dataclasses.dataclass
class TuningProcess:
    """
    Abstraction for launching a helper process to benchmark kernels. The parent
    process spawns the helper and uses multiprocessing queues to send benchmark
    requests and receive results.
    """

    device: Optional[int] = None
    process: Optional[BaseProcess] = None
    request_queue: Optional[Queue[Any]] = None
    response_queue: Optional[Queue[Any]] = None

    @staticmethod
    def process_main(
        request_queue: Queue[Any],
        response_queue: Queue[Any],
    ) -> None:
        """
        Entry point for the child process.
        """
        log.debug(
            "Entering TuningProcess child. Visible devices = %s",
            os.environ.get(CUDA_VISIBLE_DEVICES),
        )
        try:
            TuningProcess.workloop(request_queue, response_queue)
        except Exception:
            log.exception("Exception in TuningProcess")

    @staticmethod
    def workloop(request_queue: Queue[Any], response_queue: Queue[Any]) -> None:
        """
        Work loop for the benchmarking subprocess.
        """
        while True:
            obj = request_queue.get()

            if obj is None:
                break  # None is a sentinel for the child to terminate
            elif isinstance(obj, Ping):
                response_queue.put(Pong())
            elif isinstance(obj, BenchmarkRequest):
                response_queue.put(obj.benchmark())
            else:
                raise RuntimeError(f"Invalid request type {type(obj)}")

    def valid(self) -> bool:
        """
        True if the sub-process has been initialized.
        """
        return (
            self.process is not None
            and self.request_queue is not None
            and self.response_queue is not None
        )

    def clear(self) -> None:
        """
        Reset to an uninitialized state.
        """
        self.process = self.request_queue = self.response_queue = None

    def initialize(self) -> None:
        """
        Create the child process, request/response queues, and do the warm up.
        Set the environment to make only the provided GPU device visible
        to the process.
        """
        if self.valid():
            return

        # cuda runtime does not work with "fork"; use "spawn" to start processes.
        ctx = multiprocessing.get_context("spawn")
        self.request_queue = ctx.Queue()
        self.response_queue = ctx.Queue()

        self.process = ctx.Process(
            target=self.process_main,
            args=(
                self.request_queue,
                self.response_queue,
            ),
        )
        assert self.process is not None
        with set_cuda_visible_device(self.device):
            self.process.start()

    def put(self, obj: Any) -> None:
        """
        Push a work item to the child process.
        """
        # In case of a prior crash, ensure the subprocess is running
        self.initialize()
        assert self.request_queue is not None
        self.request_queue.put(obj)

    def get(
        self, result_timeout=120.0, graceful_timeout=3.0, terminate_timeout=1.0
    ) -> Any:
        """
        Get a response from the child process. Raises queue.Empty on timeout
        or if the process dies.

        This method is (so far) only used by TuningProcessPool, which populates
        the timeouts from torch._inductor.config entries:

        Arguments:

        @param result_timeout: Timeout in seconds; defaults to 120.0 or to
            config.max_autotune_subproc_result_timeout_seconds when called by
            TuningProcessPool.
        @param graceful_timeout: Timeout in seconds to allow graceful shutdown
            (SIGTERM is sent after this time). Defaults to 3.0 or to
            config.max_autotune_subproc_graceful_timeout_seconds.
        @param terminate_timeout: Timeout in seconds after SIGTERM, until we send
            SIGKILL if the process remains alive. Defaults to 1.0 or to
            config.max_autotune_subproc_terminate_timeout_seconds.

        Returns:
            A response from the child process (Any type)
        """
        assert self.process is not None
        assert self.response_queue is not None
        while True:
            try:
                remaining_timeout = result_timeout
                res = None
                while remaining_timeout is not None and remaining_timeout >= 1.0:
                    remaining_timeout -= 0.5
                    try:
                        res = self.response_queue.get(timeout=0.5)
                        break
                    except queue.Empty:
                        if not self.process.is_alive():
                            raise  # is being caught a few lines below
                if res is None:
                    res = self.response_queue.get(timeout=remaining_timeout)
                return res
            except queue.Empty:
                status = self.process.exitcode
                if status is None:
                    self.kill(
                        graceful_timeout=graceful_timeout,
                        terminate_timeout=terminate_timeout,
                    )
                else:
                    # child process crashed
                    self.clear()
                raise

    def terminate(self) -> None:
        """
        Signal the child process to terminate.
        """
        if self.valid():
            assert self.process is not None
            assert self.request_queue is not None
            self.request_queue.put(None)

    def wait(self) -> None:
        """
        Wait for the child process to exit.
        """
        if self.process is not None:
            self.process.join()
            self.clear()

    def kill(self, graceful_timeout=5.0, terminate_timeout=1.0) -> None:
        # Tries to kill the process, using a graceful_timeout in which the process
        # is allowed to exit gracefully. If the process is still alive,
        # it will be terminated. If that is not sufficient to end it
        # within terminate_timeout seconds, it will be killed.
        if self.process is not None:
            self.terminate()
            self.process.join(timeout=graceful_timeout)
            if self.process.is_alive():
                log.warning(
                    "Sending SIGTERM to process with PID %d",
                    self.process.pid,
                )
                self.process.terminate()
                self.process.join(timeout=terminate_timeout)
                if self.process.is_alive():
                    log.error(
                        "Sending SIGKILL to process with PID %d",
                        self.process.pid,
                    )
                    self.process.kill()  # This should definitely end the process
            self.clear()


@dataclasses.dataclass
class TuningProcessPool:
    """
    Maintains a pool of TuningProcesses to benchmark kernels in parallel
    across devices. By default, we create one TuningProcess per device and
    set the sub-process environment to make only that device visible.
    """

    processes: Optional[queue.Queue[TuningProcess]] = None
    executor: Optional[ThreadPoolExecutor] = None

    def initialize(self) -> None:
        """
        Start the child processes.
        """
        assert (self.processes is None) == (self.executor is None)
        if self.processes is not None:
            return

        devices = self.get_device_list()
        log.debug("Sub-process autotune device list: %s", devices)

        # Launch the child processes and push a msg to "warm up"
        self.processes = queue.Queue()
        for device in devices:
            p = TuningProcess(device=device)
            p.initialize()
            p.put(Ping())
            self.processes.put(p)

        # Wait for the initialization to finish
        for p in self.processes.queue:
            assert isinstance(p.get(result_timeout=None), Pong)

        # Use a thread pool to manage distributing work to the subprocesses.
        # Threads block on an available process, so it makes sense to match
        # the number of threads with the number of devices.
        self.executor = ThreadPoolExecutor(max_workers=len(devices))

        # Register the exit handler for the parent process so it will terminate
        # the child processes.
        global EXIT_HANDLER_REGISTERED
        if not EXIT_HANDLER_REGISTERED:
            EXIT_HANDLER_REGISTERED = True
            import atexit

            atexit.register(self.terminate)

    def get_device_list(self) -> Sequence[Optional[int]]:
        """
        Gather the list of devices to be used in the pool.
        """
        if not config.autotune_multi_device:
            # Don't use multiple devices
            return [None]

        count = torch.cuda.device_count()

        # If the user specified the visible devices in the env, use those.
        if CUDA_VISIBLE_DEVICES in os.environ:
            devices = [int(d) for d in os.environ[CUDA_VISIBLE_DEVICES].split(",")]
            assert len(devices) <= count
            return devices

        return list(range(count))

    def terminate(self) -> None:
        """
        Signal all child processes to terminate.
        """
        if self.executor is not None:
            self.executor.shutdown()
            self.executor = None

        if self.processes is not None:
            for p in self.processes.queue:
                p.terminate()
            for p in self.processes.queue:
                p.wait()
            self.processes = None

    def target(self, choice: TritonTemplateCaller) -> float:
        """
        Entry point for the thread-pool helper threads: Wait for an open TuningProcess,
        remove it from the queue, execute the benchmark in that subprocess, and return
        the TuningProcess to the queue.
        """
        assert choice.bmreq is not None
        assert self.processes is not None

        process = self.processes.get()
        process.put(choice.bmreq)
        try:
            return process.get(
                config.max_autotune_subproc_result_timeout_seconds,
                config.max_autotune_subproc_graceful_timeout_seconds,
                config.max_autotune_subproc_terminate_timeout_seconds,
            )
        except queue.Empty:
            warnings.warn(
                f"Failed to benchmark choice '{choice}'. It will be ignored. "
                "Please debug the root cause in case the choice can bring perf gains."
            )
            # set to INF so this choice will be ignored
            return float("inf")
        finally:
            self.processes.put(process)

    def benchmark(
        self,
        choices: List[TritonTemplateCaller],
    ) -> Dict[TritonTemplateCaller, float]:
        """
        Benchmark each choice in a separate process.
        """
        assert self.processes is not None, "Tuning process pool is not initialized"
        assert self.executor is not None

        results = {}

        # Use a ThreadPoolExecutor to spread the work across the subprocesses and
        # to grab subprocesses as soon as they're free.
        for choice, result in zip(choices, self.executor.map(self.target, choices)):
            results[choice] = result

        return results


tuning_pool = TuningProcessPool()


LayoutOrBuffer = Union[ir.Layout, ir.Buffer]


@dataclasses.dataclass
class TensorMeta:
    device: torch.device
    dtype: torch.dtype
    sizes: torch._prims_common.ShapeType
    strides: torch._prims_common.StrideType
    offset: int
    name: Optional[str] = None

    @classmethod
    def from_irnodes(
        cls, irnodes: Union[LayoutOrBuffer, Sequence[LayoutOrBuffer]]
    ) -> Union[TensorMeta, List[TensorMeta]]:
        if isinstance(irnodes, Sequence):
            result: List[Any] = [cls.from_irnodes(x) for x in irnodes]
            assert all(isinstance(x, TensorMeta) for x in result)
            return result

        node = irnodes
        if isinstance(node, ir.Layout):
            node = ir.Buffer("fake", node)

        dtype = node.get_dtype()
        assert dtype is not None

        return TensorMeta(
            device=node.get_device(),
            dtype=dtype,
            sizes=V.graph.sizevars.size_hints(
                node.get_size(),
                fallback=config.unbacked_symint_fallback,
            ),
            strides=V.graph.sizevars.size_hints(
                node.get_stride(),
                fallback=config.unbacked_symint_fallback,
            ),
            offset=V.graph.sizevars.size_hint(
                node.get_layout().offset,
                fallback=config.unbacked_symint_fallback,
            ),
            name=node.get_name(),
        )

    def to_tensor(self) -> torch.Tensor:
        return rand_strided(
            self.sizes,
            self.strides,
            device=self.device,
            dtype=self.dtype,
            extra_size=self.offset,
        )


@dataclasses.dataclass
class BenchmarkRequest:
    """
    Only handles triton template benchmarks for now. Extern kernel benchmarks
    can be done inside the same process since they usually don't cause crashes.

    Important: Instances of this class and subclasses have to be serializable
    across process boundaries. Do not put CUDA Tensors in here!
    """

    def __init__(
        self,
        kernel_name: str,
        input_tensor_meta: Union[TensorMeta, List[TensorMeta]],
        output_tensor_meta: Union[TensorMeta, List[TensorMeta]],
        extra_args: Iterable[Any],
    ) -> None:
        # the kernel name defined in the module
        self.kernel_name = kernel_name

        if isinstance(input_tensor_meta, TensorMeta):
            input_tensor_meta = [input_tensor_meta]
        self.input_tensor_meta = input_tensor_meta

        if isinstance(output_tensor_meta, (tuple, list)):
            assert len(output_tensor_meta) == 1
            output_tensor_meta = output_tensor_meta[0]
        self.output_tensor_meta = output_tensor_meta

        self.extra_args = extra_args

    def make_run_fn(
        self, *input_tensors: torch.Tensor, output_tensor: torch.Tensor
    ) -> Callable[[], None]:
        raise NotImplementedError

    def cleanup_run_fn(self) -> None:
        pass

    def do_bench(
        self,
        fn,
        *input_tensors: torch.Tensor,
        output_tensor: Optional[torch.Tensor] = None,
    ) -> float:
        raise NotImplementedError

    def benchmark(
        self,
        *input_tensors: torch.Tensor,
        output_tensor: Optional[torch.Tensor] = None,
    ) -> float:
        debug = log.isEnabledFor(logging.DEBUG)
        if debug:
            start_ts = time.time()

        # create args and out tensor
        if output_tensor is None:
            assert len(input_tensors) == 0
            input_tensors = tuple(x.to_tensor() for x in self.input_tensor_meta)
            output_tensor = self.output_tensor_meta.to_tensor()

        if debug:
            create_tensor_elapse = time.time() - start_ts  # type: ignore[possibly-undefined]
            start_ts = time.time()
        try:
            fn = self.make_run_fn(*input_tensors, output_tensor=output_tensor)
        except NonzeroWorkspaceNotSupportedError:
            # Skipping all ops with nonzero workspace requirements
            log.info("Skipping op due to nonzero workspace requirement")
            return float("inf")

        if debug:
            load_elapse = time.time() - start_ts  # type: ignore[possibly-undefined]
            start_ts = time.time()

        out = self.do_bench(fn, *input_tensors, output_tensor=output_tensor)

        if debug:
            bench_elapse = time.time() - start_ts  # type: ignore[possibly-undefined]
            log.debug(
                "InChildProcess %s: load %f, create tensor %f, bench %f",
                str(self),
                load_elapse,  # type: ignore[possibly-undefined]
                create_tensor_elapse,  # type: ignore[possibly-undefined]
                bench_elapse,
            )
        self.cleanup_run_fn()
        return out


class TestBenchmarkRequest(BenchmarkRequest):
    """
    Supports unit testing. Defined in this file so that the TuningProcess
    sub-process knows how to unpickle these objects.
    """

    def __init__(self, value: Optional[float] = None) -> None:
        self.value = value

    def benchmark(
        self, *input_tensors: torch.Tensor, output_tensor: Optional[torch.Tensor] = None
    ) -> float:
        if self.value is None:
            raise Exception("Failed to run")  # noqa: TRY002
        return self.value


class GPUDeviceBenchmarkRequest(BenchmarkRequest):
    def do_bench(
        self,
        fn,
        *input_tensors: torch.Tensor,
        output_tensor: Optional[torch.Tensor] = None,
    ) -> float:
        device_idx_set = {
            tensor.device.index
            for tensor in [*input_tensors, output_tensor]
            if isinstance(tensor, torch.Tensor)
            and tensor.is_cuda
            and tensor.device.index is not None
        }
        assert len(device_idx_set) <= 1, f"Can not mix devices {device_idx_set}"
        if len(device_idx_set) == 1:
            device_idx = next(iter(device_idx_set))
        else:
            device_idx = torch.cuda.current_device()

        with torch.cuda.device(device_idx):
            out = benchmarker.benchmark_gpu(fn)
            torch.cuda.synchronize()  # shake out any CUDA errors

        return out


class TritonBenchmarkRequest(GPUDeviceBenchmarkRequest):
    # Important: Instances of this class have to be serializable
    # across process boundaries. Do not put CUDA Tensors in here!
    def __init__(
        self,
        kernel_name: str,
        input_tensor_meta: Union[TensorMeta, List[TensorMeta]],
        output_tensor_meta: Union[TensorMeta, List[TensorMeta]],
        extra_args: Iterable[Any],
        module_path: str,  # the path of the module defining the triton kernel
        module_cache_key: str,
        grid: List[int],
        num_stages: int,
        num_warps: int,
        matrix_instr_nonkdim: int = 0,  # only used for hip to choose the shape of mfma instruction.
    ) -> None:
        super().__init__(kernel_name, input_tensor_meta, output_tensor_meta, extra_args)
        self.module_path = module_path
        self.module_cache_key = module_cache_key
        self.grid = grid
        self.num_stages = num_stages
        self.num_warps = num_warps
        self.matrix_instr_nonkdim = matrix_instr_nonkdim

    def make_run_fn(
        self, *input_tensors: torch.Tensor, output_tensor: torch.Tensor
    ) -> Callable[[], None]:
        mod = PyCodeCache.load_by_key_path(self.module_cache_key, self.module_path)
        log.debug(
            "benchmark module key: %s, path: %s",
            self.module_cache_key,
            self.module_path,
        )

        run_method = getattr(mod, self.kernel_name).run
        extra_args = list(self.extra_args)

        # Newer versions of triton add a warmup argument to JITFunction.run.
        # This code handles backward compatibility.
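        # Probe the run() signature at benchmark time and only forward
        # warmup=False when the parameter exists, so the same request works
        # against triton releases with and without it.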
        warmup_arg = {}
        import inspect

        if "warmup" in inspect.signature(run_method).parameters:
            warmup_arg["warmup"] = False

        from torch._C import _cuda_getCurrentRawStream as get_raw_stream

        # NOTE: the hip path (matrix_instr_nonkdim != 0) currently builds the
        # same call as the default path, so a single partial covers both.
        return functools.partial(
            run_method,
            *input_tensors,
            output_tensor,
            *self.extra_args,
            grid=self.grid,
            **warmup_arg,
            stream=get_raw_stream(self.output_tensor_meta.device.index),
        )

    def precompile(self):
        mod = PyCodeCache.load_by_key_path(self.module_cache_key, self.module_path)
        getattr(mod, self.kernel_name).precompile()

    def __str__(self) -> str:
        return f"{self.kernel_name=}, {self.module_path=}, {self.module_cache_key=}"


class CUDABenchmarkRequest(GPUDeviceBenchmarkRequest):
    # Important: Instances of this class have to be serializable
    # across process boundaries. Do not put CUDA Tensors in here!

    def __init__(
        self,
        kernel_name: str,
        input_tensor_meta: Union[TensorMeta, List[TensorMeta]],
        output_tensor_meta: Union[TensorMeta, List[TensorMeta]],
        extra_args: Iterable[Any],
        source_code: str,
    ) -> None:
        super().__init__(kernel_name, input_tensor_meta, output_tensor_meta, extra_args)
        self.source_code = source_code
        self.workspace_size: int = 0
        self.workspace: Optional[torch.Tensor] = None
        self.DLL: Optional[DLLWrapper] = None
        self._workspace_size_updated = False
        self.hash_key: str = ""
        self.source_file: str = ""
        self.hash_key, self.source_file = CUDACodeCache.write(self.source_code, "so")

    def precompile(self):
        # Prepopulate CUDACodeCache; this may happen in a separate thread pool.
        log.debug("Precompiling %s", self)
        CUDACodeCache.compile(self.source_code, "so")
        log.debug("Done precompiling %s", self)

    def make_run_fn(
        self, *input_tensors: torch.Tensor, output_tensor: torch.Tensor
    ) -> Callable[[], None]:
        self.ensure_dll_loaded()
        self.update_workspace_size()
        args = [
            c_void_p(tensor.data_ptr())
            for tensor in list(input_tensors) + [output_tensor]
        ]
        log.debug(
            "make_run_fn: self.kernel_name=%s, self.source_file=%s, self.hash_key=%s, self.DLL=%s, args=%s, self.extra_args=%s",
            self.kernel_name,
            self.source_file,
            self.hash_key,
            self.DLL,
            args,
            self.extra_args,
        )
        stream_ptr = c_void_p(torch.cuda.current_stream().cuda_stream)
        run_method = getattr(self.DLL, self.kernel_name)
        workspace_ptr = c_void_p(0)
        if self.workspace_size > 0:
            self.workspace = torch.zeros(
                (self.workspace_size + 7) // 8,
                dtype=torch.float64,
                device=output_tensor.device,
            )
            workspace_ptr = c_void_p(self.workspace.data_ptr())

        # Generate partial function.
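        # The compiled kernel's C entry point takes, in order: one pointer per
        # input/output tensor, the extra scalar args, a pointer used to return
        # the required workspace size (NULL here because update_workspace_size()
        # has already run), the workspace pointer, and the CUDA stream pointer.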
        return functools.partial(
            run_method,
            *args,
            *self.extra_args,
            None,  # null workspace size ptr
            workspace_ptr,  # set workspace ptr
            stream_ptr,
        )

    def update_workspace_size(self) -> None:
        if self._workspace_size_updated:
            return
        self.ensure_dll_loaded()
        unique_input_count = len({meta.name for meta in self.input_tensor_meta})
        args = [c_void_p(None) for _ in range(unique_input_count + 1)]
        stream_ptr = c_void_p(torch.cuda.current_stream().cuda_stream)

        run_method = getattr(self.DLL, self.kernel_name)
        # Retrieve workspace_size and initialize workspace.
        c_workspace_size = c_size_t()
        run_method(
            *args,  # input ptrs and output ptrs
            *self.extra_args,
            byref(
                c_workspace_size
            ),  # set workspace size ptr to retrieve workspace size
            None,  # null workspace ptr
            stream_ptr,
        )
        torch.cuda.synchronize()  # shake out any CUDA errors
        self.workspace_size = c_workspace_size.value
        log.debug(
            "update_workspace_size called: new workspace size=%d, self.kernel_name=%s, self.source_file=%s, self.hash_key=%s, self.DLL=%s, args=%s, self.extra_args=%s",  # noqa: B950
            self.workspace_size,
            self.kernel_name,
            self.source_file,
            self.hash_key,
            self.DLL,
            args,
            self.extra_args,
        )
        self._workspace_size_updated = True

    def ensure_dll_loaded(self):
        if self.DLL is None:
            self.DLL, self.hash_key, self.source_file = CUDACodeCache.load(
                self.source_code, "so"
            )

    def cleanup_run_fn(self) -> None:
        if self.DLL is not None:
            self.DLL.close()
        self.workspace = None

    def __str__(self) -> str:
        return f"{self.kernel_name=}, {self.source_file=}, {self.hash_key=}"


class CPUDeviceBenchmarkRequest(BenchmarkRequest):
    def do_bench(
        self,
        fn,
        *input_tensors: torch.Tensor,
        output_tensor: Optional[torch.Tensor] = None,
    ) -> float:
        return benchmarker.benchmark_cpu(fn)


class CppBenchmarkRequest(CPUDeviceBenchmarkRequest):
    # Important: Instances of this class have to be serializable
    # across process boundaries. Do not put Tensors in here!

    def __init__(
        self,
        kernel_name: str,
        input_tensor_meta: Union[TensorMeta, List[TensorMeta]],
        output_tensor_meta: Union[TensorMeta, List[TensorMeta]],
        extra_args: Iterable[Any],
        source_code: str,
    ) -> None:
        super().__init__(kernel_name, input_tensor_meta, output_tensor_meta, extra_args)
        self.source_code = source_code
        self.hash_key = get_hash(source_code)
        self.DLL: Optional[Union[CDLL, ModuleType]] = None

    def precompile(self):
        # Prepopulate CppCodeCache; this may happen in a separate thread pool.
        log.debug("Precompiling %s", self)
        CppCodeCache.load(self.source_code, cuda=False)
        log.debug("Done precompiling %s", self)

    def make_run_fn(
        self, *input_tensors: torch.Tensor, output_tensor: torch.Tensor
    ) -> Callable[[], None]:
        # TODO(jgong5): use CppPythonBindingsCodeCache for better binding perf
        self.DLL = CppCodeCache.load(self.source_code, cuda=False)
        args = [tensor.data_ptr() for tensor in list(input_tensors) + [output_tensor]]
        log.debug(
            "make_run_fn: self.kernel_name=%s, self.DLL=%s, args=%s, self.extra_args=%s",
            self.kernel_name,
            self.DLL,
            args,
            self.extra_args,
        )
        run_method = getattr(self.DLL, self.kernel_name)
        # Assume extra_args contains only sizes of type ctypes.c_ulonglong
        assert all(isinstance(arg, ctypes.c_ulonglong) for arg in self.extra_args)
        run_method.argtypes = [ctypes.c_ulonglong] * (
            len(args) + len(list(self.extra_args))
        )

        # Generate partial function.
        return functools.partial(
            run_method,
            *args,
            *self.extra_args,
        )

    def cleanup_run_fn(self) -> None:
        if self.DLL is not None:
            # Check for a close attr before calling it; closing unconditionally
            # crashes on Windows.
            if hasattr(self.DLL, "close"):
                self.DLL.close()

    def __str__(self) -> str:
        return f"{self.kernel_name=}"


def benchmark_in_sub_process(
    choices: List[TritonTemplateCaller],
) -> Dict[TritonTemplateCaller, float]:
    """
    Do benchmarking in a subprocess and return the perf number (latency).
    """
    return tuning_pool.benchmark(choices)
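
# Illustrative sketch (not part of the module API): the round trip that
# TuningProcessPool.target() performs for each choice, shown here with the
# test-only request type. TestBenchmarkRequest is defined in this file precisely
# so the child process can unpickle it; real callers go through
# benchmark_in_sub_process() with TritonTemplateCaller choices instead.
#
#     proc = TuningProcess(device=None)
#     proc.initialize()                        # spawn the child and create queues
#     proc.put(Ping())
#     assert isinstance(proc.get(), Pong)      # warm-up handshake
#     proc.put(TestBenchmarkRequest(value=3.14))
#     assert proc.get() == 3.14                # benchmark() ran in the child
#     proc.terminate()                         # send the None sentinel
#     proc.wait()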