# mypy: ignore-errors

import dataclasses
import functools
import logging
from importlib import import_module
from typing import Any, List, Optional

import torch
from functorch.compile import min_cut_rematerialization_partition
from torch import _guards
from torch._functorch import config as functorch_config
from torch._functorch.compilers import ts_compile

from .common import aot_autograd
from .registry import register_debug_backend as register_backend


log = logging.getLogger(__name__)


"""
This file contains TorchDynamo backends intended for debugging uses.
"""


@register_backend
def eager(gm, fake_tensor_inputs, **kwargs):
    if kwargs:
        log.warning("eager backend ignoring extra kwargs %s", kwargs)
    return gm.forward


@register_backend
def eager_noexcept(gm, fake_tensor_inputs, **kwargs):
    if kwargs:
        log.warning("eager_noexcept backend ignoring extra kwargs %s", kwargs)

    # This backend is intended to check that dynamo-generated GraphModules
    # do not cause errors.
    def inner(*args):
        try:
            return gm(*args)
        except Exception as e:
            raise torch._dynamo.exc.TorchDynamoException(
                "Unexpected exception when running generated GraphModule"
            ) from e

    return inner


@register_backend
def pre_dispatch_eager(gm, fake_tensor_inputs, **kwargs):
    if kwargs:
        log.warning("pre_dispatch_eager backend ignoring extra kwargs %s", kwargs)

    from torch.fx.experimental.proxy_tensor import make_fx

    def runnable_gm(*args):
        return torch.fx.Interpreter(gm).run(*args)

    pre_dispatch_gm = make_fx(runnable_gm, pre_dispatch=True)(*fake_tensor_inputs)
    pre_dispatch_gm.print_readable()

    return pre_dispatch_gm
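
# Usage sketch (illustrative, with a hypothetical ``fn``): compiling with this
# backend traces the graph via pre-dispatch make_fx and prints the resulting
# module, which is then used as the compiled callable:
#
#     torch.compile(fn, backend="pre_dispatch_eager")(torch.randn(4))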


@register_backend
def eager_debug(gm, fake_tensor_inputs, **kwargs):
    if kwargs:
        log.warning("eager_debug backend ignoring extra kwargs %s", kwargs)

    from torch._subclasses.schema_check_mode import SchemaCheckMode

    # We could add more debugging bits here.
    # Right now, this backend can be used to check for and error on
    # custom dispatcher ops that have incorrect schemas.
    def inner(*args):
        with SchemaCheckMode():
            return torch.fx.Interpreter(gm).run(*args)

    return inner
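
# Usage sketch (illustrative, with a hypothetical ``fn``): running under this
# backend executes the graph inside SchemaCheckMode, which errors on custom
# ops whose schemas do not declare their mutation/aliasing behavior correctly:
#
#     torch.compile(fn, backend="eager_debug")(torch.randn(4))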


@register_backend(name="ts")
def torchscript(gm, fake_tensor_inputs):
    return torch.jit.script(gm)


# Uses the boxed calling convention so that inputs can be discarded as soon as
# they are no longer needed.
def boxed_nop(fx_g, example_inputs):
    def run(args):
        return torch.fx.Interpreter(fx_g).boxed_run(args)

    run._boxed_call = True
    return run
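
# A "boxed" compiled function takes a single list of arguments rather than
# unpacked *args; torch.fx.Interpreter.boxed_run consumes (clears) that list so
# inputs can be freed as soon as they are no longer needed.  Rough sketch of
# the calling convention (illustrative only; names are hypothetical):
#
#     compiled = boxed_nop(fx_graph, example_inputs)
#     args = list(runtime_inputs)  # a list, not unpacked *args
#     out = compiled(args)         # args may be emptied by the call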


# aot_eager runs AOT Autograd with a no-op compiler; it is useful for
# debugging AOT Autograd issues in isolation.
aot_eager = aot_autograd(
    fw_compiler=boxed_nop,
    partition_fn=min_cut_rematerialization_partition,
    keep_inference_input_mutations=True,
)
register_backend(name="aot_eager", compiler_fn=aot_eager)
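
# Usage sketch (illustrative, with a hypothetical ``fn``): handy for checking
# whether a failure reproduces under AOT Autograd alone, without a real
# compiler in the loop:
#
#     torch.compile(fn, backend="aot_eager")(torch.randn(4))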

aot_eager_default_partitioner = aot_autograd(
    fw_compiler=boxed_nop, keep_inference_input_mutations=True
)
register_backend(
    name="aot_eager_default_partitioner", compiler_fn=aot_eager_default_partitioner
)


# aot_eager_decomp_partition uses TorchInductor's decompositions and min-cut
# partitioner with a no-op compiler, which helps isolate AOT Autograd /
# decomposition problems from Inductor codegen problems.
def aot_eager_decomp_partition(gm, fake_tensor_inputs, **kwargs):
    if kwargs:
        log.warning(
            "aot_eager_decomp_partition backend ignoring extra kwargs %s", kwargs
        )

    with functorch_config.patch(unlift_effect_tokens=True):
        return aot_autograd(
            # these are taken from memory_efficient_fusion()
            fw_compiler=boxed_nop,
            bw_compiler=boxed_nop,
            # NB: lambda here is to delay import of inductor
            decompositions=lambda: import_module(
                "torch._inductor.compile_fx"
            ).select_decomp_table(),
            partition_fn=functools.partial(
                min_cut_rematerialization_partition, compiler="inductor"
            ),
        )(gm, fake_tensor_inputs)


register_backend(
    name="aot_eager_decomp_partition", compiler_fn=aot_eager_decomp_partition
)
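
# Usage sketch (illustrative, with a hypothetical ``fn``): if a model fails
# with backend="inductor" but passes with this backend, the problem is likely
# in Inductor codegen rather than in AOT Autograd or the decompositions:
#
#     torch.compile(fn, backend="aot_eager_decomp_partition")(torch.randn(4))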


# AOT Autograd with the TorchScript (ts_compile) backend and the default
# partitioner. aot_ts can be used with both NNC and nvFuser by selecting the
# relevant fuser with torch.jit.fuser(...).
aot_ts = aot_autograd(fw_compiler=ts_compile)
register_backend(name="aot_ts", compiler_fn=aot_ts)
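
# Usage sketch (illustrative, with a hypothetical ``fn``; the fuser names
# assume the standard torch.jit.fuser API, where "fuser1" selects NNC and
# "fuser2" selects nvFuser):
#
#     with torch.jit.fuser("fuser1"):
#         torch.compile(fn, backend="aot_ts")(torch.randn(4))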

# These buggy backends are used for inducing bugs so that we can test
# our repro extraction / minifier scripts
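#
# For example (illustrative sketch): compiling a function that calls
# torch.relu with the "relu_compile_error_TESTING_ONLY" backend defined below
# fails at compile time (Dynamo typically surfaces this as a
# BackendCompilerFailed wrapping ReluCompileError), which gives the minifier
# something to reproduce:
#
#     fn = torch.compile(lambda x: torch.relu(x),
#                        backend="relu_compile_error_TESTING_ONLY")
#     fn(torch.randn(4))  # raises during compilation of the first call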


class ReluCompileError(Exception):
    pass


class TestingOnlyCompileError(Exception):
    pass


@register_backend
def relu_compile_error_TESTING_ONLY(gm: torch.fx.GraphModule, example_inputs):
    for node in gm.graph.nodes:
        if node.target == torch.relu:
            raise ReluCompileError
    return gm


@register_backend
def relu_runtime_error_TESTING_ONLY(gm: torch.fx.GraphModule, example_inputs):
    for node in gm.graph.nodes:
        if node.target == torch.relu:
            node.target = torch._assert
            node.args = (False, "ReluRuntimeError")
    gm.recompile()
    return gm


@register_backend
def relu_accuracy_error_TESTING_ONLY(gm: torch.fx.GraphModule, example_inputs):
    for node in gm.graph.nodes:
        if node.target == torch.relu:
            node.target = torch.add
            node.args = (node.args[0], 1)
    gm.recompile()

    return gm


@register_backend
def non_leaf_compile_error_TESTING_ONLY(gm: torch.fx.GraphModule, example_inputs):
    # Require at least one non-trivial thing in the graph,
    # see https://github.com/pytorch/pytorch/issues/102898
    for node in gm.graph.nodes:
        if node.op == "call_function":
            break
    else:
        return gm
    for t in example_inputs:
        if not t.is_leaf:
            raise TestingOnlyCompileError
    return gm


@dataclasses.dataclass
class ExplainOutput:
    """
    This is the output of :func:`torch._dynamo.explain()`
    There is no reason to create this class directly.
    """

    graphs: List[torch.fx.GraphModule]
    graph_count: int
    graph_break_count: int
    break_reasons: List[
        Any
    ]  # Type is GraphCompileReason but doesn't matter for this purpose
    op_count: int
    ops_per_graph: Optional[List[torch.fx.Node]] = None
    out_guards: Optional[List[_guards.Guard]] = None
    compile_times: Optional[str] = None

    def __str__(self) -> str:
        output = f"Graph Count: {self.graph_count}\n"
        output += f"Graph Break Count: {self.graph_break_count}\n"
        output += f"Op Count: {self.op_count}\n"

        output += "Break Reasons:\n"
        for idx, break_reason in enumerate(self.break_reasons):
            output += f"  Break Reason {idx+1}:\n"
            output += f"    Reason: {break_reason.reason}\n"
            output += "    User Stack:\n"
            for frame_summary in break_reason.user_stack:
                output += f"      {frame_summary}\n"

        if self.ops_per_graph is not None:
            output += "Ops per Graph:\n"
            for idx, ops in enumerate(self.ops_per_graph):
                output += f"  Ops {idx+1}:\n"
                for op in ops:
                    output += f"    {op}\n"

        if self.out_guards is not None:
            output += "Out Guards:\n"
            for i, guard in enumerate(self.out_guards):
                output += f"  Guard {i+1}:\n"
                output += f"    {str(guard)}"

        if self.compile_times is not None:
            output += f"Compile Times: {self.compile_times}\n"
        return output
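
# Typical way an ExplainOutput is produced (illustrative sketch, assuming the
# ``torch._dynamo.explain`` API; ``fn`` is a hypothetical user function):
#
#     def fn(x):
#         x = x.sin()
#         print("this print forces a graph break")
#         return x.cos()
#
#     explanation = torch._dynamo.explain(fn)(torch.randn(4))
#     print(explanation)  # formatted via ExplainOutput.__str__ above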


def _explain_graph_detail(
    gm: torch.fx.GraphModule, graphs, op_count, ops_per_graph, break_reasons
):
    """
    This function is a utility which processes a torch.fx.GraphModule and
    accumulates information about its ops, graph breaks, and other details. It
    is intended to be used by the ExplainWithBackend class and
    `torch._dynamo.explain()` to provide details from Dynamo's graph capture.

    Parameters:
        gm (torch.fx.GraphModule): The GraphModule to be processed.
        graphs (list): A list that accumulates all the GraphModules processed.
        op_count (int): The total count of operations in all GraphModules processed so far.
        ops_per_graph (list): A list that accumulates the operations of each GraphModule.
        break_reasons (list): A list that accumulates the reasons for breaks in each GraphModule.

    Returns:
        tuple: (gm, graphs, op_count, ops_per_graph, break_reasons) - the processed
               GraphModule, the updated list of graphs, the updated operation count,
               the updated list of operations per graph, and the updated list of
               break reasons.
    """
    graphs.append(gm)
    ops = [node.target for node in gm.graph.nodes if node.op == "call_function"]
    op_count += len(ops)
    ops_per_graph.append(ops)
    if gm.compile_subgraph_reason.graph_break:
        break_reasons.append(gm.compile_subgraph_reason)

    return gm, graphs, op_count, ops_per_graph, break_reasons


class ExplainWithBackend:
    """
    This class is intended to be used as a backend for `torch.compile`. It is
    composable with other backends. When used in this way, it accumulates
    information about graph breaks, ops, and other info and provides a string
    representation summarizing this information.

    Attributes:
        backend: The wrapped backend compiler function, resolved via lookup_backend.
        graphs (list): A list of the graphs captured by TorchDynamo.
        op_count (int): The total number of operations in all optimized graphs.
        break_reasons (list): A list of graph break reasons with stack traces.

    Example Usage:
        def fn(x):
            x = torch.sigmoid(x)
            return x

        torch._dynamo.reset()
        eb = ExplainWithBackend("inductor")
        optimized_fn = torch.compile(fn, backend=eb)
        result = optimized_fn(torch.randn(5))
        print(eb.output())
    """

    def __init__(self, backend) -> None:
        from .registry import lookup_backend

        self.backend = lookup_backend(backend)
        self.graphs = []
        self.op_count = 0
        self.break_reasons = []

    def __call__(self, gm: torch.fx.GraphModule, example_inputs):
        gm, self.graphs, self.op_count, _, self.break_reasons = _explain_graph_detail(
            gm, self.graphs, self.op_count, [], self.break_reasons
        )
        return self.backend(gm, example_inputs)

    def output(self) -> ExplainOutput:
        graph_count = len(self.graphs)
        output = ExplainOutput(
            self.graphs,
            graph_count,
            graph_count - 1,
            self.break_reasons,
            self.op_count,
        )

        return output
337