xref: /aosp_15_r20/external/pytorch/torch/_inductor/codegen/debug_utils.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# mypy: allow-untyped-defs
2from __future__ import annotations
3
4import functools
5import logging
6from enum import Enum
7from typing import List, Optional
8
9from torch import dtype as torch_dtype
10
11from .. import config
12from ..virtualized import V
13from .multi_kernel import MultiKernel
14
15
16log = logging.getLogger(__name__)
17
18
19# AOTI debug printing related configs
class IntermediateValueDebuggingLevel(Enum):
    """Selects how AOTI intermediate tensor values are surfaced for debugging.

    Each member's value is a string, so a raw setting such as ``"1"`` can be
    converted with ``IntermediateValueDebuggingLevel("1")``.
    """

    # Debugging disabled: intermediate tensor values are neither printed nor saved.
    OFF = "0"
    # Level 1: every intermediate tensor value is saved to its own `.pt` file;
    # nothing is printed.
    SAVE_ONLY = "1"
    # Level 2: every intermediate tensor value is printed to the console by
    # default; nothing is saved.
    PRINT_ONLY = "2"
27
28
class DebugPrinterManager:
    """Context manager that emits debug print/save statements around a kernel launch.

    Entering the context generates the "before_launch" statements for the
    currently configured kernel and arguments; leaving it generates the
    "after_launch" statements.  Depending on ``debug_printer_level`` nothing,
    tensor-save calls, or tensor-print calls are generated.  All generated
    lines are written through ``V.graph.wrapper_code.writeline``.

    The kernel/argument state can be replaced between launches with
    ``set_printer_args``.
    """

    def __init__(
        self,
        debug_printer_level,
        args_to_print_or_save: Optional[List[str]] = None,
        kernel_name: str = "",
        kernel=None,
        arg_signatures: Optional[List[type]] = None,
    ):
        """Store the debugging level and the initial kernel/argument state.

        Args:
            debug_printer_level: An ``IntermediateValueDebuggingLevel`` member
                or one of its string values (``"0"``, ``"1"``, ``"2"``).
            args_to_print_or_save: Names of the kernel arguments to print or
                save; defaults to an empty list.
            kernel_name: Name used to label the generated statements.
            kernel: The kernel object the arguments belong to.
            arg_signatures: Optional per-argument signatures, parallel to
                ``args_to_print_or_save``.  When given, only arguments whose
                signature is a ``torch.dtype`` are treated as tensors.

        Raises:
            ValueError: If ``debug_printer_level`` is not a valid level.
        """
        self.debug_printer_level = IntermediateValueDebuggingLevel(debug_printer_level)
        if args_to_print_or_save is None:
            args_to_print_or_save = []
        self.args_to_print_or_save = args_to_print_or_save
        self.kernel_name = kernel_name
        # Keep the caller-supplied signatures instead of discarding them.
        self.arg_signatures: Optional[List[type]] = arg_signatures
        self.kernel = kernel
        # Computed once per instance; the kernel-name filter is read from config.
        self.filtered_kernel_names_to_print = self._get_debug_filtered_kernel_names()

    def __enter__(self):
        """Emit the "before_launch" debug statements and return this manager."""
        self._perform_debug_print_or_save_helper(
            self.args_to_print_or_save,
            self.kernel_name,
            before_launch=True,
            arg_signatures=self.arg_signatures,
        )
        return self

    def __exit__(self, args_to_print_or_save, kernel_name, arg_signatures):
        """Emit the "after_launch" debug statements.

        The ``with`` statement invokes this method with the exception type,
        exception value and traceback (all ``None`` on a clean exit).  The
        parameter names are kept as-is for compatibility, but the values they
        receive are exception details, so they are deliberately not forwarded;
        the state stored on the instance is used instead.

        Returns ``None``, so an exception raised inside the ``with`` body is
        not suppressed.
        """
        self._perform_debug_print_or_save_helper(
            self.args_to_print_or_save,
            self.kernel_name,
            before_launch=False,
            arg_signatures=self.arg_signatures,
        )

    def _perform_debug_print_or_save_helper(
        self,
        args_to_print_or_save,
        kernel_name,
        before_launch,
        arg_signatures: Optional[List[type]] = None,
    ):
        """Dispatch to the save or print code generator for the current level.

        Args:
            args_to_print_or_save: Argument names to generate statements for.
            kernel_name: Kernel name used to label the statements.
            before_launch: ``True`` for the pre-launch pass, ``False`` for the
                post-launch pass.
            arg_signatures: Optional per-argument signatures used to skip
                non-tensor arguments.
        """
        level = self.debug_printer_level
        if level == IntermediateValueDebuggingLevel.OFF:
            return
        if level == IntermediateValueDebuggingLevel.SAVE_ONLY:
            self.codegen_intermediate_tensor_value_save(
                args_to_print_or_save,
                kernel_name,
                before_launch,
                arg_signatures=arg_signatures,
            )
        elif level == IntermediateValueDebuggingLevel.PRINT_ONLY:
            self.codegen_intermediate_tensor_value_print(
                args_to_print_or_save,
                kernel_name,
                before_launch,
                arg_signatures=arg_signatures,
            )

    def _get_debug_filtered_kernel_names(self) -> List[str]:
        """Return the configured kernel-name filter as lower-cased names.

        Reads the comma-separated ``config.aot_inductor.filtered_kernel_names``
        setting.  Returns an empty list when the setting is ``None``.  Blank
        entries (from an empty string or stray commas) are dropped so they
        cannot turn the filter into one that matches no kernel at all.
        """
        raw_names = config.aot_inductor.filtered_kernel_names
        if raw_names is None:
            return []
        stripped = (name.strip() for name in raw_names.lower().split(","))
        return [name for name in stripped if name]

    def set_printer_args(
        self,
        args_to_print_or_save: List[str],
        kernel_name: str,
        arg_signatures: Optional[List[type]],
        kernel,
    ):
        """Replace the kernel/argument state used by the next ``with`` block.

        Args:
            args_to_print_or_save: Argument names to print or save.
            kernel_name: Kernel name used to label the statements.
            arg_signatures: Optional per-argument signatures, parallel to
                ``args_to_print_or_save``.
            kernel: The kernel object; a ``MultiKernel`` turns debugging off.
        """
        # Note: MultiKernel debug printing is not supported for now
        if isinstance(kernel, MultiKernel):
            log.info(
                "MultiKernel type is not supported in AOTI debug printer tool yet."
            )
            # NOTE(review): the level is never restored, so once a MultiKernel
            # is seen this manager stays OFF for later kernels too -- confirm
            # that this is intended.
            self.debug_printer_level = IntermediateValueDebuggingLevel.OFF
        self.args_to_print_or_save = args_to_print_or_save
        self.kernel_name = kernel_name
        self.arg_signatures = arg_signatures
        self.kernel = kernel

    @staticmethod
    def _iter_tensor_args(args, arg_signatures):
        """Yield the entries of ``args`` that should be treated as tensors.

        Without signatures every argument is yielded.  With signatures, an
        argument is yielded only if the signature at the same index is a
        ``torch.dtype``.  A signature list shorter than ``args`` raises
        ``IndexError`` when the missing index is reached.
        """
        for index, arg in enumerate(args):
            if arg_signatures is not None and not isinstance(
                arg_signatures[index], torch_dtype
            ):
                # infer from the arg data type (has torch.dtype) to see if it is a tensor type
                continue
            yield arg

    def _is_kernel_filtered_out(self, kernel_name) -> bool:
        """Return whether printing for ``kernel_name`` is suppressed by the filter.

        The filter only applies at the PRINT_ONLY level and only when at least
        one kernel name is configured.  The comparison is case-insensitive
        because the configured names are stored lower-cased.
        """
        if self.debug_printer_level != IntermediateValueDebuggingLevel.PRINT_ONLY:
            return False
        allowed_names = self.filtered_kernel_names_to_print
        if not allowed_names:
            return False
        return kernel_name.lower() not in allowed_names

    def codegen_intermediate_tensor_value_save(
        self,
        args_to_save,
        kernel_name,
        before_launch=True,
        arg_signatures: Optional[List[type]] = None,
    ) -> None:
        """Generate one tensor-save call per tensor argument.

        A line is only written in cpp-wrapper mode with ``config.abi_compatible``
        enabled; in every other mode this method generates nothing.

        Args:
            args_to_save: Argument names to generate save calls for.
            kernel_name: Kernel name embedded in the generated call.
            before_launch: Selects the "before_launch" or "after_launch" label.
            arg_signatures: Optional per-argument signatures used to skip
                non-tensor arguments.
        """
        launch_prefix = "before_launch" if before_launch else "after_launch"
        for arg in self._iter_tensor_args(args_to_save, arg_signatures):
            # TODO: add non-abi compatible mode debug printing info
            # Saving is currently not supported outside cpp wrapper codegen mode.
            if V.graph.cpp_wrapper and config.abi_compatible:
                V.graph.wrapper_code.writeline(
                    f'aoti_torch_save_tensor_handle({arg}, "{arg}", "{launch_prefix}", "{kernel_name}");'
                )

    def codegen_intermediate_tensor_value_print(
        self,
        args_to_print,
        kernel_name,
        before_launch=True,
        arg_signatures: Optional[List[type]] = None,
    ) -> None:
        """Generate one tensor-print statement per tensor argument.

        In cpp-wrapper mode a print call is written only when
        ``config.abi_compatible`` is enabled; otherwise a Python ``print``
        statement is written.  At the PRINT_ONLY level, kernels that are not
        in a non-empty kernel-name filter generate nothing.

        Args:
            args_to_print: Argument names to generate print statements for.
            kernel_name: Kernel name embedded in the generated label.
            before_launch: Selects the "before_launch" or "after_launch" label.
            arg_signatures: Optional per-argument signatures used to skip
                non-tensor arguments.
        """
        # The filter depends only on the kernel name, so decide once up front.
        if self._is_kernel_filtered_out(kernel_name):
            return

        launch_prefix = "before_launch" if before_launch else "after_launch"
        for arg in self._iter_tensor_args(args_to_print, arg_signatures):
            if V.graph.cpp_wrapper:
                if config.abi_compatible:
                    V.graph.wrapper_code.writeline(
                        f'aoti_torch_print_tensor_handle({arg}, "{launch_prefix} - {kernel_name} - {arg}");'
                    )
                else:
                    # TODO: add non-abi compatible mode debug printing info
                    pass
            else:
                line = f"print('{launch_prefix} - {kernel_name} - {arg}', {arg})"
                V.graph.wrapper_code.writeline(line)
176