xref: /aosp_15_r20/external/pytorch/torch/_inductor/autotune_process.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# mypy: allow-untyped-defs
2from __future__ import annotations
3
4import contextlib
5import ctypes
6import dataclasses
7import functools
8import logging
9import os
10import queue
11import time
12import warnings
13from concurrent.futures import ThreadPoolExecutor
14from ctypes import byref, c_size_t, c_void_p, CDLL
15from typing import (
16    Any,
17    Callable,
18    Dict,
19    Iterable,
20    List,
21    Optional,
22    Sequence,
23    TYPE_CHECKING,
24    Union,
25)
26
27import torch
28import torch._inductor.async_compile  # noqa: F401 required to warm up AsyncCompile pools
29from torch import multiprocessing
30from torch._dynamo.testing import rand_strided
31from torch._inductor import ir
32from torch._inductor.codecache import (
33    CppCodeCache,
34    CUDACodeCache,
35    DLLWrapper,
36    get_hash,
37    PyCodeCache,
38)
39
40
41if TYPE_CHECKING:
42    from multiprocessing.process import BaseProcess
43    from multiprocessing.queues import Queue
44    from types import ModuleType
45
46    from torch._inductor.select_algorithm import TritonTemplateCaller
47
48from . import config
49from .runtime.benchmarking import benchmarker
50from .virtualized import V
51
52
53CUDA_VISIBLE_DEVICES = "CUDA_VISIBLE_DEVICES"
54EXIT_HANDLER_REGISTERED = False
55
56log = logging.getLogger(__name__)
57
58
59# Used to synchronize between parent and child processes
60class Ping:
61    pass
62
63
64class Pong:
65    pass
66
67
68class NonzeroWorkspaceNotSupportedError(Exception):
69    pass
70
71
72@contextlib.contextmanager
73def set_cuda_visible_device(device: Optional[int]):
74    """
75    Context manager to set the CUDA_VISIBLE_DEVICES environment variable to the
76    specified single device. If device is None, don't manipulate the environment.
77    """
78    if device is None:
79        yield
80        return
81
82    current = os.environ.get(CUDA_VISIBLE_DEVICES)
83    os.environ[CUDA_VISIBLE_DEVICES] = str(device)
84    try:
85        yield
86    finally:
87        if current is None:
88            del os.environ[CUDA_VISIBLE_DEVICES]
89        else:
90            os.environ[CUDA_VISIBLE_DEVICES] = current
91
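# Editor-added illustrative sketch (not part of the original module): typical use
# of the context manager above to pin a child process to one GPU. The index 0 is
# just an example value.
def _example_set_cuda_visible_device():
    with set_cuda_visible_device(0):
        # Inside the block the variable holds only the requested device, so any
        # subprocess started here sees a single GPU.
        assert os.environ[CUDA_VISIBLE_DEVICES] == "0"
    # On exit, the previous value (or absence) of the variable is restored.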
92
93@dataclasses.dataclass
94class TuningProcess:
95    """
96    Abstraction for launching a helper process to benchmark kernels. Spawns
97    a child process and uses multiprocessing queues to send benchmark
98    requests and return results.
99    """
100
101    device: Optional[int] = None
102    process: Optional[BaseProcess] = None
103    request_queue: Optional[Queue[Any]] = None
104    response_queue: Optional[Queue[Any]] = None
105
106    @staticmethod
107    def process_main(
108        request_queue: Queue[Any],
109        response_queue: Queue[Any],
110    ) -> None:
111        """
112        Entry point for the child process.
113        """
114        log.debug(
115            "Entering TuningProcess child. Visible devices = %s",
116            os.environ.get(CUDA_VISIBLE_DEVICES),
117        )
118        try:
119            TuningProcess.workloop(request_queue, response_queue)
120        except Exception:
121            log.exception("Exception in TuningProcess")
122
123    @staticmethod
124    def workloop(request_queue: Queue[Any], response_queue: Queue[Any]) -> None:
125        """
126        Work loop for the benchmarking subprocess.
127        """
128        while True:
129            obj = request_queue.get()
130
131            if obj is None:
132                break  # None is a sentinel for the child to terminate
133            elif isinstance(obj, Ping):
134                response_queue.put(Pong())
135            elif isinstance(obj, BenchmarkRequest):
136                response_queue.put(obj.benchmark())
137            else:
138                raise RuntimeError(f"Invalid request type {type(obj)}")
139
140    def valid(self) -> bool:
141        """
142        True if the sub-process has been initialized.
143        """
144        return (
145            self.process is not None
146            and self.request_queue is not None
147            and self.response_queue is not None
148        )
149
150    def clear(self) -> None:
151        """
152        Reset to an uninitialized state.
153        """
154        self.process = self.request_queue = self.response_queue = None
155
156    def initialize(self) -> None:
157        """
158        Create child process, request/response queues, and do the warm up.
159        Set the environment to make only the provided GPU device visible
160        to the process.
161        """
162        if self.valid():
163            return
164
165        # The CUDA runtime does not work with "fork"; use "spawn" to start processes.
166        ctx = multiprocessing.get_context("spawn")
167        self.request_queue = ctx.Queue()
168        self.response_queue = ctx.Queue()
169
170        self.process = ctx.Process(
171            target=self.process_main,
172            args=(
173                self.request_queue,
174                self.response_queue,
175            ),
176        )
177        assert self.process is not None
178        with set_cuda_visible_device(self.device):
179            self.process.start()
180
181    def put(self, obj: Any) -> None:
182        """
183        Push a work item to the child process.
184        """
185        # In case of a prior crash, ensure the subprocess is running
186        self.initialize()
187        assert self.request_queue is not None
188        self.request_queue.put(obj)
189
190    def get(
191        self, result_timeout=120.0, graceful_timeout=3.0, terminate_timeout=1.0
192    ) -> Any:
193        """
194        Get a response from the child process. Raises queue.Empty on timeout
195        or if the process dies.
196
197        This method is (so far) only used by TuningProcessPool, which populates
198        the timeouts from torch._inductor.config entries.
199
200        Arguments:
201            result_timeout: Timeout in seconds. Defaults to 120.0, or to
202                config.max_autotune_subproc_result_timeout_seconds.
203            graceful_timeout: Timeout in seconds to allow a graceful shutdown
204                (SIGTERM is sent after this time). Defaults to 3.0, or to
205                config.max_autotune_subproc_graceful_timeout_seconds.
206            terminate_timeout: Timeout in seconds after SIGTERM, before SIGKILL
207                is sent if the process remains alive. Defaults to 1.0, or to
208                config.max_autotune_subproc_terminate_timeout_seconds.
209        Returns:
210            A response from the child process (Any type)
211        """
212        assert self.process is not None
213        assert self.response_queue is not None
214        while True:
215            try:
216                remaining_timeout = result_timeout
217                res = None
218                while remaining_timeout is not None and remaining_timeout >= 1.0:
219                    remaining_timeout -= 0.5
220                    try:
221                        res = self.response_queue.get(timeout=0.5)
222                        break
223                    except queue.Empty:
224                        if not self.process.is_alive():
225                            raise  # is being caught a few lines below
226                if res is None:
227                    res = self.response_queue.get(timeout=remaining_timeout)
228                return res
229            except queue.Empty:
230                status = self.process.exitcode
231                if status is None:
232                    self.kill(
233                        graceful_timeout=graceful_timeout,
234                        terminate_timeout=terminate_timeout,
235                    )
236                else:
237                    # child process crashed
238                    self.clear()
239                raise
240
241    def terminate(self) -> None:
242        """
243        Signal the child process to terminate.
244        """
245        if self.valid():
246            assert self.process is not None
247            assert self.request_queue is not None
248            self.request_queue.put(None)
249
250    def wait(self) -> None:
251        """
252        Wait for the child process to exit.
253        """
254        if self.process is not None:
255            self.process.join()
256            self.clear()
257
258    def kill(self, graceful_timeout=5.0, terminate_timeout=1.0) -> None:
259        # Tries to kill the process, using a graceful_timeout in which the process
260        # is allowed to exit gracefully. If the process is still alive,
261        # it will be terminated. If that is not sufficient to end it
262        # within terminate_timeout seconds, it will be killed.
263        if self.process is not None:
264            self.terminate()
265            self.process.join(timeout=graceful_timeout)
266            if self.process.is_alive():
267                log.warning(
268                    "Sending SIGTERM to process with PID %d",
269                    self.process.pid,
270                )
271                self.process.terminate()
272                self.process.join(timeout=terminate_timeout)
273                if self.process.is_alive():
274                    log.error(
275                        "Sending SIGKILL to process with PID %d",
276                        self.process.pid,
277                    )
278                    self.process.kill()  # This should definitely end the process
279            self.clear()
280
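# Editor-added illustrative sketch (not part of the original module): driving a
# single TuningProcess by hand. This mirrors the warm-up that TuningProcessPool
# performs below; device=None leaves CUDA_VISIBLE_DEVICES untouched.
def _example_tuning_process_roundtrip():
    tp = TuningProcess(device=None)
    tp.initialize()  # spawn the child process and create the queues
    tp.put(Ping())  # warm-up request
    assert isinstance(tp.get(result_timeout=None), Pong)
    tp.terminate()  # pushes the None sentinel so the child's workloop exits
    tp.wait()  # join the child and reset this TuningProcess to uninitialized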
281
282@dataclasses.dataclass
283class TuningProcessPool:
284    """
285    Maintains a pool of TuningProcesses to benchmark kernels in parallel
286    across devices. By default, we create one TuningProcess per device and
287    set the sub-process environment to make only that device visible.
288    """
289
290    processes: Optional[queue.Queue[TuningProcess]] = None
291    executor: Optional[ThreadPoolExecutor] = None
292
293    def initialize(self) -> None:
294        """
295        Start the child processes.
296        """
297        assert (self.processes is None) == (self.executor is None)
298        if self.processes is not None:
299            return
300
301        devices = self.get_device_list()
302        log.debug("Sub-process autotune device list: %s", devices)
303
304        # Launch the child processes and push a msg to "warm up"
305        self.processes = queue.Queue()
306        for device in devices:
307            p = TuningProcess(device=device)
308            p.initialize()
309            p.put(Ping())
310            self.processes.put(p)
311
312        # Wait for the initialization to finish
313        for p in self.processes.queue:
314            assert isinstance(p.get(result_timeout=None), Pong)
315
316        # Use a thread pool to manage distributing work to the subprocesses.
317        # Threads block on an available process, so it makes sense to match
318        # the number of threads with the number of devices.
319        self.executor = ThreadPoolExecutor(max_workers=len(devices))
320
321        # Register the exit handler for the parent process so it will terminate
322        # the child processes.
323        global EXIT_HANDLER_REGISTERED
324        if not EXIT_HANDLER_REGISTERED:
325            EXIT_HANDLER_REGISTERED = True
326            import atexit
327
328            atexit.register(self.terminate)
329
330    def get_device_list(self) -> Sequence[Optional[int]]:
331        """
332        Gather the list of devices to be used in the pool.
333        """
334        if not config.autotune_multi_device:
335            # Don't use multiple devices
336            return [None]
337
338        count = torch.cuda.device_count()
339
340        # If the user specified the visible devices in the env, use those.
341        if CUDA_VISIBLE_DEVICES in os.environ:
342            devices = [int(d) for d in os.environ[CUDA_VISIBLE_DEVICES].split(",")]
343            assert len(devices) <= count
344            return devices
345
346        return list(range(count))
347
348    def terminate(self) -> None:
349        """
350        Signal all child processes to terminate.
351        """
352        if self.executor is not None:
353            self.executor.shutdown()
354            self.executor = None
355
356        if self.processes is not None:
357            for p in self.processes.queue:
358                p.terminate()
359            for p in self.processes.queue:
360                p.wait()
361            self.processes = None
362
363    def target(self, choice: TritonTemplateCaller) -> float:
364        """
365        Entry point for the thread-pool helper threads: Wait for an open TuningProcess,
366        remove it from the queue, execute the benchmark in that subprocess, and return
367        the TuningProcess to the queue.
368        """
369        assert choice.bmreq is not None
370        assert self.processes is not None
371
372        process = self.processes.get()
373        process.put(choice.bmreq)
374        try:
375            return process.get(
376                config.max_autotune_subproc_result_timeout_seconds,
377                config.max_autotune_subproc_graceful_timeout_seconds,
378                config.max_autotune_subproc_terminate_timeout_seconds,
379            )
380        except queue.Empty:
381            warnings.warn(
382                f"Failed to benchmark choice '{choice}'. It will be ignored. "
383                "Please debug the root cause in case the choice can bring perf gains."
384            )
385            # set to INF so this choice will be ignored
386            return float("inf")
387        finally:
388            self.processes.put(process)
389
390    def benchmark(
391        self,
392        choices: List[TritonTemplateCaller],
393    ) -> Dict[TritonTemplateCaller, float]:
394        """
395        Benchmark each choice in a separate process.
396        """
397        assert self.processes is not None, "Tuning process pool is not initialized"
398        assert self.executor is not None
399
400        results = {}
401
402        # Use a ThreadPoolExecutor to spread the work across the subprocesses and
403        # to grab subprocesses as soon as they're free.
404        for choice, result in zip(choices, self.executor.map(self.target, choices)):
405            results[choice] = result
406
407        return results
408
409
410tuning_pool = TuningProcessPool()
411
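# Editor-added illustrative sketch (not part of the original module): the shared
# pool above is normally driven through benchmark_in_sub_process() at the bottom
# of this file. `choices` stands for TritonTemplateCaller objects produced by the
# autotuner; the helper name is hypothetical and only shows the call shape.
def _example_pick_fastest_choice(choices):
    tuning_pool.initialize()  # lazily starts one child process per device
    timings = tuning_pool.benchmark(choices)  # {choice: latency, inf on failure}
    return min(timings, key=timings.get) if timings else None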
412
413LayoutOrBuffer = Union[ir.Layout, ir.Buffer]
414
415
416@dataclasses.dataclass
417class TensorMeta:
418    device: torch.device
419    dtype: torch.dtype
420    sizes: torch._prims_common.ShapeType
421    strides: torch._prims_common.StrideType
422    offset: int
423    name: Optional[str] = None
424
425    @classmethod
426    def from_irnodes(
427        cls, irnodes: Union[LayoutOrBuffer, Sequence[LayoutOrBuffer]]
428    ) -> Union[TensorMeta, List[TensorMeta]]:
429        if isinstance(irnodes, Sequence):
430            result: List[Any] = [cls.from_irnodes(x) for x in irnodes]
431            assert all(isinstance(x, TensorMeta) for x in result)
432            return result
433
434        node = irnodes
435        if isinstance(node, ir.Layout):
436            node = ir.Buffer("fake", node)
437
438        dtype = node.get_dtype()
439        assert dtype is not None
440
441        return TensorMeta(
442            device=node.get_device(),
443            dtype=dtype,
444            sizes=V.graph.sizevars.size_hints(
445                node.get_size(),
446                fallback=config.unbacked_symint_fallback,
447            ),
448            strides=V.graph.sizevars.size_hints(
449                node.get_stride(),
450                fallback=config.unbacked_symint_fallback,
451            ),
452            offset=V.graph.sizevars.size_hint(
453                node.get_layout().offset,
454                fallback=config.unbacked_symint_fallback,
455            ),
456            name=node.get_name(),
457        )
458
459    def to_tensor(self) -> torch.Tensor:
460        return rand_strided(
461            self.sizes,
462            self.strides,
463            device=self.device,
464            dtype=self.dtype,
465            extra_size=self.offset,
466        )
467
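# Editor-added illustrative sketch (not part of the original module): a TensorMeta
# built from plain values (rather than from IR nodes via from_irnodes) and
# materialized into a random strided tensor, as benchmark() does for its inputs.
def _example_tensor_meta_roundtrip():
    meta = TensorMeta(
        device=torch.device("cpu"),
        dtype=torch.float32,
        sizes=(4, 8),
        strides=(8, 1),  # contiguous row-major strides for a 4x8 tensor
        offset=0,
    )
    return meta.to_tensor().shape  # torch.Size([4, 8])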
468
469@dataclasses.dataclass
470class BenchmarkRequest:
471    """
472    Only handles Triton template benchmarks for now. Extern kernel benchmarks
473    can be done inside the same process since they usually don't cause crashes.
474
475    Important: Instances of this class and subclasses have to be serializable
476    across process boundaries. Do not put CUDA Tensors in here!
477    """
478
479    def __init__(
480        self,
481        kernel_name: str,
482        input_tensor_meta: Union[TensorMeta, List[TensorMeta]],
483        output_tensor_meta: Union[TensorMeta, List[TensorMeta]],
484        extra_args: Iterable[Any],
485    ) -> None:
486        # the kernel name defined in the module
487        self.kernel_name = kernel_name
488
489        if isinstance(input_tensor_meta, TensorMeta):
490            input_tensor_meta = [input_tensor_meta]
491        self.input_tensor_meta = input_tensor_meta
492
493        if isinstance(output_tensor_meta, (tuple, list)):
494            assert len(output_tensor_meta) == 1
495            output_tensor_meta = output_tensor_meta[0]
496        self.output_tensor_meta = output_tensor_meta
497
498        self.extra_args = extra_args
499
500    def make_run_fn(
501        self, *input_tensors: torch.Tensor, output_tensor: torch.Tensor
502    ) -> Callable[[], None]:
503        raise NotImplementedError
504
505    def cleanup_run_fn(self) -> None:
506        pass
507
508    def do_bench(
509        self,
510        fn,
511        *input_tensors: torch.Tensor,
512        output_tensor: Optional[torch.Tensor] = None,
513    ) -> float:
514        raise NotImplementedError
515
516    def benchmark(
517        self,
518        *input_tensors: torch.Tensor,
519        output_tensor: Optional[torch.Tensor] = None,
520    ) -> float:
521        debug = log.isEnabledFor(logging.DEBUG)
522        if debug:
523            start_ts = time.time()
524
525        # create args and out tensor
526        if output_tensor is None:
527            assert len(input_tensors) == 0
528            input_tensors = tuple(x.to_tensor() for x in self.input_tensor_meta)
529            output_tensor = self.output_tensor_meta.to_tensor()
530
531        if debug:
532            create_tensor_elapse = time.time() - start_ts  # type: ignore[possibly-undefined]
533            start_ts = time.time()
534        try:
535            fn = self.make_run_fn(*input_tensors, output_tensor=output_tensor)
536        except NonzeroWorkspaceNotSupportedError:
537            # Skipping all ops with nonzero workspace requirements
538            log.info("Skipping op due to nonzero workspace requirement")
539            return float("inf")
540
541        if debug:
542            load_elapse = time.time() - start_ts  # type: ignore[possibly-undefined]
543            start_ts = time.time()
544
545        out = self.do_bench(fn, *input_tensors, output_tensor)
546
547        if debug:
548            bench_elapse = time.time() - start_ts  # type: ignore[possibly-undefined]
549            log.debug(
550                "InChildProcess %s: load %f, create tensor %f, bench %f",
551                str(self),
552                load_elapse,  # type: ignore[possibly-undefined]
553                create_tensor_elapse,  # type: ignore[possibly-undefined]
554                bench_elapse,
555            )
556        self.cleanup_run_fn()
557        return out
558
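# Editor-added illustrative sketch (not part of the original module): the minimal
# contract a concrete BenchmarkRequest subclass fulfills. The real subclasses
# below (Triton/CUDA/Cpp) follow this same shape: make_run_fn builds a no-arg
# callable, do_bench times it, and cleanup_run_fn releases any resources.
def _example_benchmark_request_subclass():
    class _AddOneBenchmarkRequest(BenchmarkRequest):
        def make_run_fn(self, *input_tensors, output_tensor):
            return lambda: torch.add(input_tensors[0], 1, out=output_tensor)

        def do_bench(self, fn, *input_tensors, output_tensor=None):
            return benchmarker.benchmark_cpu(fn)

    return _AddOneBenchmarkRequest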
559
560class TestBenchmarkRequest(BenchmarkRequest):
561    """
562    Supports unit testing. Defined in this file so that the TuningProcess
563    sub-process knows how to unpickle these objects.
564    """
565
566    def __init__(self, value: Optional[float] = None) -> None:
567        self.value = value
568
569    def benchmark(
570        self, *input_tensors: torch.Tensor, output_tensor: Optional[torch.Tensor] = None
571    ) -> float:
572        if self.value is None:
573            raise Exception("Failed to run")  # noqa: TRY002
574        return self.value
575
576
577class GPUDeviceBenchmarkRequest(BenchmarkRequest):
578    def do_bench(
579        self,
580        fn,
581        *input_tensors: torch.Tensor,
582        output_tensor: Optional[torch.Tensor] = None,
583    ) -> float:
584        device_idx_set = {
585            tensor.device.index
586            for tensor in [*input_tensors, output_tensor]
587            if isinstance(tensor, torch.Tensor)
588            and tensor.is_cuda
589            and tensor.device.index is not None
590        }
591        assert len(device_idx_set) <= 1, f"Cannot mix devices {device_idx_set}"
592        if len(device_idx_set) == 1:
593            device_idx = next(iter(device_idx_set))
594        else:
595            device_idx = torch.cuda.current_device()
596
597        with torch.cuda.device(device_idx):
598            out = benchmarker.benchmark_gpu(fn)
599            torch.cuda.synchronize()  # shake out any CUDA errors
600
601        return out
602
603
604class TritonBenchmarkRequest(GPUDeviceBenchmarkRequest):
605    # Important: Instances of this class have to be serializable
606    # across process boundaries. Do not put CUDA Tensors in here!
607    def __init__(
608        self,
609        kernel_name: str,
610        input_tensor_meta: Union[TensorMeta, List[TensorMeta]],
611        output_tensor_meta: Union[TensorMeta, List[TensorMeta]],
612        extra_args: Iterable[Any],
613        module_path: str,  # the path of the module defining the triton kernel
614        module_cache_key: str,
615        grid: List[int],
616        num_stages: int,
617        num_warps: int,
618        matrix_instr_nonkdim: int = 0,  # only used on HIP to choose the shape of the mfma instruction
619    ) -> None:
620        super().__init__(kernel_name, input_tensor_meta, output_tensor_meta, extra_args)
621        self.module_path = module_path
622        self.module_cache_key = module_cache_key
623        self.grid = grid
624        self.num_stages = num_stages
625        self.num_warps = num_warps
626        self.matrix_instr_nonkdim = matrix_instr_nonkdim
627
628    def make_run_fn(
629        self, *input_tensors: torch.Tensor, output_tensor: torch.Tensor
630    ) -> Callable[[], None]:
631        mod = PyCodeCache.load_by_key_path(self.module_cache_key, self.module_path)
632        log.debug(
633            "benchmark module key: %s, path: %s",
634            self.module_cache_key,
635            self.module_path,
636        )
637
638        run_method = getattr(mod, self.kernel_name).run
639        extra_args = list(self.extra_args)
640
641        # Newer versions of Triton add a warmup argument to JITFunction.run.
642        # This code handles backward compatibility.
643        warmup_arg = {}
644        import inspect
645
646        if "warmup" in inspect.signature(run_method).parameters:
647            warmup_arg["warmup"] = False
648
649        from torch._C import _cuda_getCurrentRawStream as get_raw_stream
650
651        # Both the HIP (matrix_instr_nonkdim != 0) and the default code paths
652        # currently build the same call, so a single partial covers both.
653        return functools.partial(
654            run_method,
655            *input_tensors,
656            output_tensor,
657            *self.extra_args,
658            grid=self.grid,
659            **warmup_arg,
660            stream=get_raw_stream(self.output_tensor_meta.device.index),
661        )
671
672    def precompile(self):
673        mod = PyCodeCache.load_by_key_path(self.module_cache_key, self.module_path)
674        getattr(mod, self.kernel_name).precompile()
675
676    def __str__(self) -> str:
677        return f"{self.kernel_name=}, {self.module_path=}, {self.module_cache_key=}"
678
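# Editor-added illustrative sketch (not part of the original module): the
# "pass a keyword only if the callee accepts it" pattern used by make_run_fn
# above for Triton's warmup argument, demonstrated on a plain function.
def _example_optional_kwarg_dispatch():
    import inspect

    def newer_api(x, warmup=True):
        return x, warmup

    kwargs = {}
    if "warmup" in inspect.signature(newer_api).parameters:
        kwargs["warmup"] = False
    return newer_api(1, **kwargs)  # -> (1, False); older APIs would get no kwarg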
679
680class CUDABenchmarkRequest(GPUDeviceBenchmarkRequest):
681    # Important: Instances of this class have to be serializable
682    # across process boundaries. Do not put CUDA Tensors in here!
683
684    def __init__(
685        self,
686        kernel_name: str,
687        input_tensor_meta: Union[TensorMeta, List[TensorMeta]],
688        output_tensor_meta: Union[TensorMeta, List[TensorMeta]],
689        extra_args: Iterable[Any],
690        source_code: str,
691    ) -> None:
692        super().__init__(kernel_name, input_tensor_meta, output_tensor_meta, extra_args)
693        self.source_code = source_code
694        self.workspace_size: int = 0
695        self.workspace: Optional[torch.Tensor] = None
696        self.DLL: Optional[DLLWrapper] = None
697        self._workspace_size_updated = False
698        self.hash_key: str = ""
699        self.source_file: str = ""
700        self.hash_key, self.source_file = CUDACodeCache.write(self.source_code, "so")
701
702    def precompile(self):
703        # Prepopulate CUDACodeCache.
704        # This may happen in a separate thread pool.
705        log.debug("Precompiling %s", self)
706        CUDACodeCache.compile(self.source_code, "so")
707        log.debug("Done precompiling %s", self)
708
709    def make_run_fn(
710        self, *input_tensors: torch.Tensor, output_tensor: torch.Tensor
711    ) -> Callable[[], None]:
712        self.ensure_dll_loaded()
713        self.update_workspace_size()
714        args = [
715            c_void_p(tensor.data_ptr())
716            for tensor in list(input_tensors) + [output_tensor]
717        ]
718        log.debug(
719            "make_run_fn: self.kernel_name=%s, self.source_file=%s, self.hash_key=%s, self.DLL=%s, args=%s, self.extra_args=%s",
720            self.kernel_name,
721            self.source_file,
722            self.hash_key,
723            self.DLL,
724            args,
725            self.extra_args,
726        )
727        stream_ptr = c_void_p(torch.cuda.current_stream().cuda_stream)
728        run_method = getattr(self.DLL, self.kernel_name)
729        workspace_ptr = c_void_p(0)
730        if self.workspace_size > 0:
731            self.workspace = torch.zeros(
732                (self.workspace_size + 7) // 8,
733                dtype=torch.float64,
734                device=output_tensor.device,
735            )
736            workspace_ptr = c_void_p(self.workspace.data_ptr())
737
738        # Generate partial function.
739        return functools.partial(
740            run_method,
741            *args,
742            *self.extra_args,
743            None,  # null workspace size ptr
744            workspace_ptr,  # workspace ptr
745            stream_ptr,
746        )
747
748    def update_workspace_size(self) -> None:
749        if self._workspace_size_updated:
750            return
751        self.ensure_dll_loaded()
752        unique_input_count = len({meta.name for meta in self.input_tensor_meta})
753        args = [c_void_p(None) for _ in range(unique_input_count + 1)]
754        stream_ptr = c_void_p(torch.cuda.current_stream().cuda_stream)
755
756        run_method = getattr(self.DLL, self.kernel_name)
757        # Retrieve workspace_size and initialize workspace.
758        c_workspace_size = c_size_t()
759        run_method(
760            *args,  # input ptrs and output ptrs
761            *self.extra_args,
762            byref(
763                c_workspace_size
764            ),  # set workspace size ptr to retrieve workspace size
765            None,  # null workspace ptr
766            stream_ptr,
767        )
768        torch.cuda.synchronize()  # shake out any CUDA errors
769        self.workspace_size = c_workspace_size.value
770        log.debug(
771            "update_workspace_size called: new workspace size=%d, self.kernel_name=%s, self.source_file=%s, self.hash_key=%s, self.DLL=%s, args=%s, self.extra_args=%s",  # noqa: B950
772            self.workspace_size,
773            self.kernel_name,
774            self.source_file,
775            self.hash_key,
776            self.DLL,
777            args,
778            self.extra_args,
779        )
780        self._workspace_size_updated = True
781
782    def ensure_dll_loaded(self):
783        if self.DLL is None:
784            self.DLL, self.hash_key, self.source_file = CUDACodeCache.load(
785                self.source_code, "so"
786            )
787
788    def cleanup_run_fn(self) -> None:
789        if self.DLL is not None:
790            self.DLL.close()
791        self.workspace = None
792
793    def __str__(self) -> str:
794        return f"{self.kernel_name=}, {self.source_file=}, {self.hash_key=}"
795
796
797class CPUDeviceBenchmarkRequest(BenchmarkRequest):
798    def do_bench(
799        self,
800        fn,
801        *input_tensors: torch.Tensor,
802        output_tensor: Optional[torch.Tensor] = None,
803    ) -> float:
804        return benchmarker.benchmark_cpu(fn)
805
806
807class CppBenchmarkRequest(CPUDeviceBenchmarkRequest):
808    # Important: Instances of this class have to be serializable
809    # across process boundaries. Do not put Tensors in here!
810
811    def __init__(
812        self,
813        kernel_name: str,
814        input_tensor_meta: Union[TensorMeta, List[TensorMeta]],
815        output_tensor_meta: Union[TensorMeta, List[TensorMeta]],
816        extra_args: Iterable[Any],
817        source_code: str,
818    ) -> None:
819        super().__init__(kernel_name, input_tensor_meta, output_tensor_meta, extra_args)
820        self.source_code = source_code
821        self.hash_key = get_hash(source_code)
822        self.DLL: Optional[Union[CDLL, ModuleType]] = None
823
824    def precompile(self):
825        # Prepopulate CppCodeCache.
826        # This may happen in a separate thread pool.
827        log.debug("Precompiling %s", self)
828        CppCodeCache.load(self.source_code, cuda=False)
829        log.debug("Done precompiling %s", self)
830
831    def make_run_fn(
832        self, *input_tensors: torch.Tensor, output_tensor: torch.Tensor
833    ) -> Callable[[], None]:
834        # TODO(jgong5): use CppPythonBindingsCodeCache for better binding perf
835        self.DLL = CppCodeCache.load(self.source_code, cuda=False)
836        args = [tensor.data_ptr() for tensor in list(input_tensors) + [output_tensor]]
837        log.debug(
838            "make_run_fn: self.kernel_name=%s, self.DLL=%s, args=%s, self.extra_args=%s",
839            self.kernel_name,
840            self.DLL,
841            args,
842            self.extra_args,
843        )
844        run_method = getattr(self.DLL, self.kernel_name)
845        # Assume extra_args contains only sizes of type ctypes.c_ulonglong
846        assert all(isinstance(arg, ctypes.c_ulonglong) for arg in self.extra_args)
847        run_method.argtypes = [ctypes.c_ulonglong] * (
848            len(args) + len(list(self.extra_args))
849        )
850
851        # Generate partial function.
852        return functools.partial(
853            run_method,
854            *args,
855            *self.extra_args,
856        )
857
858    def cleanup_run_fn(self) -> None:
859        if self.DLL is not None:
860            # Check for the close attribute before calling it: calling
861            # close() unconditionally on this handle can crash the
862            # process on Windows.
863            if hasattr(self.DLL, "close"):
864                self.DLL.close()
865
866    def __str__(self) -> str:
867        return f"{self.kernel_name=}"
868
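# Editor-added illustrative sketch (not part of the original module): the ctypes
# argtypes/call pattern used by CppBenchmarkRequest.make_run_fn above, shown on
# the C runtime's labs() so it is runnable without a generated kernel (POSIX
# only; CDLL(None) loads the already-linked C library).
def _example_ctypes_argtypes():
    libc = ctypes.CDLL(None)
    labs = libc.labs  # long labs(long)
    labs.argtypes = [ctypes.c_long]
    labs.restype = ctypes.c_long
    return labs(-5)  # -> 5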
869
870def benchmark_in_sub_process(
871    choices: List[TritonTemplateCaller],
872) -> Dict[TritonTemplateCaller, float]:
873    """
874    Do benchmarking in a subprocess and return the perf number (latency).
875    """
876    return tuning_pool.benchmark(choices)
877