1# mypy: ignore-errors 2 3import dataclasses 4import functools 5import logging 6from importlib import import_module 7from typing import Any, List, Optional 8 9import torch 10from functorch.compile import min_cut_rematerialization_partition 11from torch import _guards 12from torch._functorch import config as functorch_config 13from torch._functorch.compilers import ts_compile 14 15from .common import aot_autograd 16from .registry import register_debug_backend as register_backend 17 18 19log = logging.getLogger(__name__) 20 21 22""" 23This file contains TorchDynamo backends intended for debugging uses. 24""" 25 26 27@register_backend 28def eager(gm, fake_tensor_inputs, **kwargs): 29 if kwargs: 30 log.warning("eager backend ignoring extra kwargs %s", kwargs) 31 return gm.forward 32 33 34@register_backend 35def eager_noexcept(gm, fake_tensor_inputs, **kwargs): 36 if kwargs: 37 log.warning("eager_noexcept backend ignoring extra kwargs %s", kwargs) 38 39 # This backend is intended to check that dynamo-generated GraphModules 40 # do not cause errors. 41 def inner(*args): 42 try: 43 return gm(*args) 44 except Exception as e: 45 raise torch._dynamo.exc.TorchDynamoException( 46 "Unexpected exception when running generated GraphModule" 47 ) from e 48 49 return inner 50 51 52@register_backend 53def pre_dispatch_eager(gm, fake_tensor_inputs, **kwargs): 54 if kwargs: 55 log.warning("pre_dispatch_eager backend ignoring extra kwargs %s", kwargs) 56 57 from torch.fx.experimental.proxy_tensor import make_fx 58 59 def runnable_gm(*args): 60 return torch.fx.Interpreter(gm).run(*args) 61 62 pre_dispatch_gm = make_fx(runnable_gm, pre_dispatch=True)(*fake_tensor_inputs) 63 pre_dispatch_gm.print_readable() 64 65 return pre_dispatch_gm 66 67 68@register_backend 69def eager_debug(gm, fake_tensor_inputs, **kwargs): 70 if kwargs: 71 log.warning("eager_debug backend ignoring extra kwargs %s", kwargs) 72 73 from torch._subclasses.schema_check_mode import SchemaCheckMode 74 75 # We could add more debugging bits here. 76 # Right now, this backend can be used to check for and error on 77 # custom dispatcher ops that have incorrect schemas. 78 def inner(*args): 79 with SchemaCheckMode(): 80 return torch.fx.Interpreter(gm).run(*args) 81 82 return inner 83 84 85@register_backend(name="ts") 86def torchscript(gm, fake_tensor_inputs): 87 return torch.jit.script(gm) 88 89 90# used boxed call to discard inputs when they are no longer needed 91def boxed_nop(fx_g, example_inputs): 92 def run(args): 93 return torch.fx.Interpreter(fx_g).boxed_run(args) 94 95 run._boxed_call = True 96 return run 97 98 99# Useful for debugging purpose 100# aot_eager uses AOT Autograd backend with nop compiler. It is helpful in debugging. 101aot_eager = aot_autograd( 102 fw_compiler=boxed_nop, 103 partition_fn=min_cut_rematerialization_partition, 104 keep_inference_input_mutations=True, 105) 106register_backend(name="aot_eager", compiler_fn=aot_eager) 107 108aot_eager_default_partitioner = aot_autograd( 109 fw_compiler=boxed_nop, keep_inference_input_mutations=True 110) 111register_backend( 112 name="aot_eager_default_partitioner", compiler_fn=aot_eager_default_partitioner 113) 114 115 116# Uses TorchInductor AOT Autograd decomps and partitioner to isolate aot vs 117# inductor problems. 118# aot_eager_decomp_partition just replaces the inductor compiler with nop to help 119# isolate inductor vs aot_eager errors 120def aot_eager_decomp_partition(gm, fake_tensor_inputs, **kwargs): 121 if kwargs: 122 log.warning( 123 "aot_eager_decomp_partition backend ignoring extra kwargs %s", kwargs 124 ) 125 126 with functorch_config.patch(unlift_effect_tokens=True): 127 return aot_autograd( 128 # these are taken from memory_efficient_fusion() 129 fw_compiler=boxed_nop, 130 bw_compiler=boxed_nop, 131 # NB: lambda here is to delay import of inductor 132 decompositions=lambda: import_module( 133 "torch._inductor.compile_fx" 134 ).select_decomp_table(), 135 partition_fn=functools.partial( 136 min_cut_rematerialization_partition, compiler="inductor" 137 ), 138 )(gm, fake_tensor_inputs) 139 140 141register_backend( 142 name="aot_eager_decomp_partition", compiler_fn=aot_eager_decomp_partition 143) 144 145 146# AOT Autograd with torchscript backend. Default partitioner. 147# aot_ts uses torchscript backend. We can use this with both nnc and nvfuser 148# by using the relevant fuser with torch.jit.fuser(...) 149aot_ts = aot_autograd(fw_compiler=ts_compile) 150register_backend(name="aot_ts", compiler_fn=aot_ts) 151 152# These buggy backends are used for inducing bugs so that we can test 153# our repro extraction / minifier scripts 154 155 156class ReluCompileError(Exception): 157 pass 158 159 160class TestingOnlyCompileError(Exception): 161 pass 162 163 164@register_backend 165def relu_compile_error_TESTING_ONLY(gm: torch.fx.GraphModule, example_inputs): 166 for node in gm.graph.nodes: 167 if node.target == torch.relu: 168 raise ReluCompileError 169 return gm 170 171 172@register_backend 173def relu_runtime_error_TESTING_ONLY(gm: torch.fx.GraphModule, example_inputs): 174 for node in gm.graph.nodes: 175 if node.target == torch.relu: 176 node.target = torch._assert 177 node.args = (False, "ReluRuntimeError") 178 gm.recompile() 179 return gm 180 181 182@register_backend 183def relu_accuracy_error_TESTING_ONLY(gm: torch.fx.GraphModule, example_inputs): 184 for node in gm.graph.nodes: 185 if node.target == torch.relu: 186 node.target = torch.add 187 node.args = (node.args[0], 1) 188 gm.recompile() 189 190 return gm 191 192 193@register_backend 194def non_leaf_compile_error_TESTING_ONLY(gm: torch.fx.GraphModule, example_inputs): 195 # Require at least one non-trivial thing in the graph, 196 # see https://github.com/pytorch/pytorch/issues/102898 197 for node in gm.graph.nodes: 198 if node.op == "call_function": 199 break 200 else: 201 return gm 202 for t in example_inputs: 203 if not t.is_leaf: 204 raise TestingOnlyCompileError 205 return gm 206 207 208@dataclasses.dataclass 209class ExplainOutput: 210 """ 211 This is the output of :func:`torch._dynamo.explain()` 212 There is no reason to create this class directly. 213 """ 214 215 graphs: List[torch.fx.GraphModule] 216 graph_count: int 217 graph_break_count: int 218 break_reasons: List[ 219 Any 220 ] # Type is GraphCompileReason but doesn't matter for this purpose 221 op_count: int 222 ops_per_graph: Optional[List[torch.fx.Node]] = None 223 out_guards: Optional[List[_guards.Guard]] = None 224 compile_times: Optional[str] = None 225 226 def __str__(self) -> str: 227 output = f"Graph Count: {self.graph_count}\n" 228 output += f"Graph Break Count: {self.graph_break_count}\n" 229 output += f"Op Count: {self.op_count}\n" 230 231 output += "Break Reasons:\n" 232 for idx, break_reason in enumerate(self.break_reasons): 233 output += f" Break Reason {idx+1}:\n" 234 output += f" Reason: {break_reason.reason}\n" 235 output += " User Stack:\n" 236 for frame_summary in break_reason.user_stack: 237 output += f" {frame_summary}\n" 238 239 if self.ops_per_graph is not None: 240 output += "Ops per Graph:\n" 241 for idx, ops in enumerate(self.ops_per_graph): 242 output += f" Ops {idx+1}:\n" 243 for op in ops: 244 output += f" {op}\n" 245 246 if self.out_guards is not None: 247 output += "Out Guards:\n" 248 for i, guard in enumerate(self.out_guards): 249 output += f" Guard {i+1}:\n" 250 output += f" {str(guard)}" 251 252 if self.compile_times is not None: 253 output += f"Compile Times: {self.compile_times}\n" 254 return output 255 256 257def _explain_graph_detail( 258 gm: torch.fx.GraphModule, graphs, op_count, ops_per_graph, break_reasons 259): 260 """ 261 This function is a utility which processes a torch.fx.GraphModule and 262 accumulates information about its ops, graph breaks, and other details. It 263 is intended to be used by the ExplainWithBackend class and 264 `torch._dynamo.explain()` to provide details from Dynamo's graph capture. 265 266 Parameters: 267 gm (torch.fx.GraphModule): The GraphModule to be processed. 268 graphs (list): A list that accumulates all the GraphModules processed. 269 op_count (int): The total count of operations in all GraphModules processed so far. 270 ops_per_graph (list): A list that accumulates the operations of each GraphModule. 271 break_reasons (list): A list that accumulates the reasons for breaks in each GraphModule. 272 273 Returns: 274 tuple: A tuple containing the processed GraphModule, the updated lists of graphs, 275 operations per graph, and break reasons, and the updated operation count. 276 """ 277 graphs.append(gm) 278 ops = [node.target for node in gm.graph.nodes if node.op == "call_function"] 279 op_count += len(ops) 280 ops_per_graph.append(ops) 281 if gm.compile_subgraph_reason.graph_break: 282 break_reasons.append(gm.compile_subgraph_reason) 283 284 return gm, graphs, op_count, ops_per_graph, break_reasons 285 286 287class ExplainWithBackend: 288 """ 289 This class is intended to be used as a backend for `torch.compile`. It is 290 composable with other backends. When used in this way, it accumulates 291 information about graph breaks, ops, and other info and provides a string 292 representation summarizing this information. 293 294 Attributes: 295 backend (str): The name of the backend to use for optimization. 296 graphs (list): A list of the graphs captured by TorchDynamo. 297 op_count (int): The total number of operations in all optimized graphs. 298 break_reasons (list): A list of graph break reasons with stack traces. 299 300 Example Usage: 301 def fn(x): 302 x = torch.sigmoid(x) 303 return x 304 305 torch._dynamo.reset() 306 eb = ExplainWithBackend("inductor") 307 optimized_fn = torch.compile(fn, backend=eb) 308 result = optimized_fn(torch.randn(5)) 309 print(eb.output()) 310 """ 311 312 def __init__(self, backend) -> None: 313 from .registry import lookup_backend 314 315 self.backend = lookup_backend(backend) 316 self.graphs = [] 317 self.op_count = 0 318 self.break_reasons = [] 319 320 def __call__(self, gm: torch.fx.GraphModule, example_inputs): 321 gm, self.graphs, self.op_count, _, self.break_reasons = _explain_graph_detail( 322 gm, self.graphs, self.op_count, [], self.break_reasons 323 ) 324 return self.backend(gm, example_inputs) 325 326 def output(self) -> ExplainOutput: 327 graph_count = len(self.graphs) 328 output = ExplainOutput( 329 self.graphs, 330 graph_count, 331 graph_count - 1, 332 self.break_reasons, 333 self.op_count, 334 ) 335 336 return output 337