# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

import copy
import json
import traceback
from contextlib import contextmanager
from dataclasses import asdict, dataclass
from typing import (
    Any,
    Callable,
    Dict,
    Generator,
    Iterable,
    List,
    Optional,
    Set,
    Tuple,
    Union,
)

import executorch.extension.pytree as ex_pytree
import torch
import torch._dynamo as torchdynamo
import torch.fx as fx

import torch.fx._pytree as fx_pytree
import torch.utils._pytree as pytree

from executorch.exir.common import (
    extract_out_arguments,
    format_schema_name,
    no_dispatch,
    setting_python_recursive_limit,
)
from executorch.exir.error import ExportError, ExportErrorType, InternalError
from executorch.exir.graph_module import LeafValue
from executorch.exir.operator.convert import is_out_variant
from executorch.exir.types import ValueSpec

from torch._C import _EnableTorchFunction, DisableTorchFunctionSubclass  # @manual
from torch._decomp import get_decompositions
from torch._dynamo.guards import Guard
from torch._functorch.eager_transforms import _maybe_unwrap_functional_tensor
from torch.export import default_decompositions
from torch.func import functionalize
from torch.fx.operator_schemas import normalize_function
from torch.utils._pytree import TreeSpec

from typing_extensions import TypeAlias


Value: TypeAlias = Union[
    LeafValue,
    Tuple["Value", ...],
    List["Value"],
    Dict[str, "Value"],
]

torchdynamo_enabled = False


def get_stacktrace() -> List[Dict[str, str]]:
    """
    Get the current stacktrace (between trace() and __torch_dispatch__()).
    Include the filename, function name, line number, and source code from the
    start of the function to the given instruction.

    Returns:
        A list of stacktraces for each instruction along with the source code
        context surrounding each instruction.
    """

    stacktrace = traceback.extract_stack()

    # The stacktrace typically looks like this:
    #
    # 1. I stack frames from the top level runner (e.g., the
    #    test suite runner)
    # 2. J frames in executorch/exir/tracer.py setting up tracing
    #    (call this INIT_EXIR)
    # 3. K frames in user model code (this is what we want to save!)
    # 4. 1 frame in executorch/exir/tracer.py __torch_function__
    #    returning to tracer (call this TRACE_EXIR)
    # 5. H frames in executorch/exir/tracer.py AND torch/_tensor.py
    #    doing all of the internal tracer handling
    #
    # The PyE tests assert that executorch/exir/tracer.py never shows
    # up in the user provided stack traces, so we must oblige them.
    #
    # Assumptions:
    # - Reentrant tracing is not a thing. Thus, the first time
    #   executorch/exir/tracer.py shows up in the trace, we know
    #   THAT is the point at which we start tracing. (An alternative
    #   is that the tracer entry point could record the stack trace
    #   at this time, but I didn't do this.)
    #
    # Our plan is to run a miniature state machine over these stack
    # frames.

    # Remove parts before the trace function and parts after entering
    # __torch_dispatch__. Defaults to returning the entire stack trace.
    init_exir_end = 0
    trace_exir_start = None
    # A miniature state machine, referring to the frame segments described
    # above. The locations are closed-open intervals.
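    # Illustrative sketch only (the frame list below is hypothetical),
    # assuming the layout described above:
    #
    #   [runner.py, tracer.py, tracer.py, model.py, model.py, tracer.py, ...]
    #       0          1          2          3         4         5
    #
    # Here init_exir_end lands on index 3 (first user frame after the
    # INIT_EXIR frames) and trace_exir_start on index 5 (first tracer frame
    # after user code), so the slice below keeps only the two model.py frames.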
    FIND_INIT_EXIR_START, FIND_INIT_EXIR_END, FIND_TRACE_EXIR_START = range(3)
    state = FIND_INIT_EXIR_START
    for i, frame in enumerate(stacktrace):
        if state == FIND_INIT_EXIR_START:
            if "executorch/exir/tracer.py" in frame.filename:
                state = FIND_INIT_EXIR_END
        elif state == FIND_INIT_EXIR_END:
            if "executorch/exir/tracer.py" not in frame.filename:
                init_exir_end = i
                state = FIND_TRACE_EXIR_START
        elif state == FIND_TRACE_EXIR_START:
            if "executorch/exir/tracer.py" in frame.filename:
                trace_exir_start = i
                break

    stacktrace = stacktrace[init_exir_end:trace_exir_start]

    # Gather the source code context surrounding each remaining frame
    contexts: List[str] = []
    for s in stacktrace:
        try:
            with open(s.filename) as file:
                # pyre-fixme[6]: For 1st param expected `Union[SupportsTrunc, bytes,
                #  str, SupportsInt, SupportsIndex]` but got `Optional[int]`.
                lineno = int(s.lineno)
                # Get the source code 5 lines above/below the current instruction
                file_contents = [
                    str(index + 1) + line for index, line in enumerate(file.readlines())
                ]
                file_contents_above = "".join(
                    file_contents[max(lineno - 5, 0) : lineno]
                )
                file_contents_below = "".join(
                    file_contents[lineno : min(lineno + 5, len(file_contents))]
                )
                context = (
                    file_contents_above
                    + "^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n"
                    + file_contents_below
                )
                contexts.append(context)
        except FileNotFoundError:
            contexts.append("<unknown file: unknown line>")

    # torch.fx stack preservation logic expects strings to be passed
    # around. Working with dictionaries makes it a lot easier to convert
    # to strings and vice versa.
    frames: List[Dict[str, str]] = []
    for i, frame in enumerate(stacktrace):
        frames.append(
            {
                "filename": str(frame.filename),
                "lineno": str(frame.lineno),
                "name": str(frame.name),
                "line": str(frame.line),
                "context": contexts[i],
            }
        )

    return frames


def unwrap_functional(t: torch.Tensor) -> torch.Tensor:
    assert isinstance(t, torch.Tensor)
    return _maybe_unwrap_functional_tensor(t, reapply_views=False)


def unwrap_proxy(t: LeafValue) -> Union[LeafValue, torch.fx.Proxy]:
    if not isinstance(t, torch.Tensor):
        return t
    t = unwrap_functional(t)
    return t.proxy if isinstance(t, PythonTensor) else t


def single_return(
    output: LeafValue,
    proxy: torch.fx.Proxy,
    wrapper: Callable[..., LeafValue],
) -> LeafValue:
    if isinstance(output, torch.Tensor):
        return wrapper(output, proxy)

    return output


def tree_return(
    outputs: Value,
    proxy: torch.fx.Proxy,
    wrapper: Callable[..., LeafValue],
    meta_type: Callable[..., Iterable[ValueSpec]] = tuple,
) -> Value:
    i: int = 0

    def wrap(o: LeafValue) -> LeafValue:
        nonlocal i
        ret = single_return(o, proxy[i], wrapper)
        i += 1
        return ret

    return pytree.tree_map(wrap, outputs)


class DummyProxy:
    def __init__(self) -> None:
        class DummyNode:
            def __init__(self):
                self.meta = {}

        self.node = DummyNode()

    def __getitem__(self, key: str) -> "DummyProxy":
        return DummyProxy()

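# A minimal sketch of how the wrapper class below is used (illustrative only;
# the tensor and `some_proxy` are hypothetical stand-ins for values the
# tracer creates):
#
#   wrapped = PythonTensor(torch.randn(2), some_proxy)  # records future ops
#   unwrap_proxy(wrapped)   # -> some_proxy, suitable as an FX node argument
#   unwrap_proxy(3.14)      # non-tensors pass through unchanged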
class PythonTensor(torch.Tensor):
    """
    A wrapper tensor subclass used in the DispatchTracer to keep track of
    proxies to construct the FX graph.

    Wrapping something in PythonTensor implicitly detaches gradients. If
    something requires grad, we will collect it as if it were a leaf. A
    consequence of detaching in this way is that you need to maintain a
    parameter cache when translating tensors into PythonTensor, so that you
    don't create multiple copies of a gradient (they are aliased, but they
    would count as independent leaves). An alternate strategy would be to
    avoid implicitly detaching and instead "catch" gradients as they exit
    the PythonTensor boundary.
    """

    __slots__ = ["proxy", "is_immutable"]

    @staticmethod
    def __new__(
        cls, elem: torch.Tensor, proxy: torch.fx.Proxy, is_immutable: bool = False
    ) -> torch.Tensor:
        # assert not elem.requires_grad or not torch.is_grad_enabled()

        r = torch.Tensor._make_subclass(cls, elem, elem.requires_grad)
        assert isinstance(r, PythonTensor)
        r.is_immutable: bool = is_immutable
        r.update_proxy(proxy)
        return r

    def update_proxy(self, proxy: torch.fx.Proxy) -> None:
        self.proxy = proxy

    def __repr__(self, *, tensor_contents: None = None) -> str:
        with no_dispatch():
            return f"PythonTensor({self.as_subclass(torch.Tensor)})"

    @classmethod
    def __torch_function__(
        cls,
        # pyre-ignore: Missing parameter annotation [2]
        func,
        # pyre-ignore: Missing parameter annotation [2]
        types,
        args: Tuple[Value, ...] = (),
        kwargs: Optional[Dict[str, Value]] = None,
    ) -> Value:
        if kwargs is None:
            kwargs = {}
        if torch.is_inference_mode_enabled():
            if func is torch.nn.functional.layer_norm:
                args, kwargs = normalize_function(func, args, kwargs)  # pyre-fixme[23]
                input, normalized_shape = args
                normalized_shape = list(normalized_shape)
                return cls.__torch_dispatch__(
                    torch.ops.aten.layer_norm.default,
                    types,
                    (input, normalized_shape),
                    kwargs,
                )
            elif func is torch.nn.functional.linear:
                return cls.__torch_dispatch__(
                    torch.ops.aten.linear.default, types, args, kwargs
                )
        with DisableTorchFunctionSubclass():
            return func(*args, **kwargs)

    @classmethod
    def __torch_dispatch__(  # noqa: C901
        cls,
        func_overload: torch._ops.OpOverload,
        # pyre-ignore: Missing parameter annotation [2]
        types,
        args: Tuple[Value, ...] = (),
        kwargs: Optional[Dict[str, Value]] = None,
    ) -> Value:
        """
        This function is invoked every time an aten operation is called.

        Args:
            func_overload: The function that was called that invoked this
                torch_dispatch call
            types:
            args: Arguments that were passed into the function. Each argument
                has type PythonTensor.
            kwargs: Keyword arguments that were passed into the function. Each
                argument has type PythonTensor.
        """
        func = func_overload.overloadpacket

        kwargs = kwargs or {}
        if is_out_variant(func._qualified_op_name, func_overload._overloadname):
            out_args = extract_out_arguments(func_overload._schema, kwargs)
            out_args_iter = [out_args] if not isinstance(out_args, list) else out_args
            for out_arg_name, out_arg_val in out_args_iter:
                if isinstance(out_arg_val, PythonTensor) and out_arg_val.is_immutable:
                    raise RuntimeError(
                        "Immutable tensor `{}` is potentially getting modified by {}".format(
                            out_arg_name, format_schema_name(func_overload._schema)
                        )
                    )
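        # At this point args/kwargs may mix PythonTensors with plain tensors
        # and scalars. For example (hypothetical values):
        #   args = (PythonTensor(...), 2.0)  ->  proxy_args = (Proxy(node), 2.0)
        # so the proxied call below records an FX node whose inputs reference
        # earlier graph nodes rather than concrete tensors.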
        # pyre-fixme[16]: Module `pytree` has no attribute `tree_map`.
        proxy_args = ex_pytree.tree_map(unwrap_proxy, args)
        # pyre-fixme[16]: Module `pytree` has no attribute `tree_map`.
        proxy_kwargs = ex_pytree.tree_map(unwrap_proxy, kwargs)

        # Get the output of the function
        g = _EnableTorchFunction()
        try:
            proxy_out = (
                func_overload(*proxy_args, **proxy_kwargs)
                if DispatchTracer.get() or torchdynamo_enabled
                # Disable node creation when no tracer is active.
                else DummyProxy()
            )
        finally:
            del g

        with no_dispatch():
            real_out = func_overload(*args, **kwargs)

        # Kind of a hacky way to test whether an op is in-place or not
        if func.__name__[-1] == "_" and func.__name__[0] != "_":
            if isinstance(args[0], PythonTensor):
                args[0].proxy = proxy_out

        if not torch.fx.traceback.has_preserved_node_meta():
            proxy_out.node.meta["stack_trace"] = json.dumps(get_stacktrace())

        # Wrap the output tensors with the PythonTensor subclass to propagate to
        # future tracing
        def wrap_with_proxy(e: LeafValue, proxy: torch.fx.Proxy) -> LeafValue:
            # Some ops (like native_batch_norm_backward) return undefined tensors that get
            # converted into None in Python.
            # As the function signature expects tensors, if we directly return these None
            # tensors back to C++, we'll error.
            if e is None:
                e = torch.empty(())

            # In-place and out-variant ops may return one of their arguments, which is
            # already a PythonTensor. In this case, we need to update the PythonTensor's
            # associated proxy to the newly created proxy. Check this before the plain
            # torch.Tensor case below, since PythonTensor is a torch.Tensor subclass and
            # would otherwise be re-wrapped instead of updated.
            if isinstance(e, PythonTensor):
                e.update_proxy(proxy)
                return e

            if isinstance(e, torch.Tensor):
                return PythonTensor(e, proxy)

            return e

        retval = None
        if not isinstance(real_out, (list, tuple)):
            retval = single_return(real_out, proxy_out, wrap_with_proxy)
        else:
            retval = tree_return(real_out, proxy_out, wrap_with_proxy, type(real_out))
        return retval


@contextmanager
def using_tracer(tracer: Optional["DispatchTracer"]) -> Generator[None, None, None]:
    """
    Set the "current" global tracer within the scope of the using_tracer
    context manager.

    Since various things we want to capture today with torch_dispatch do
    not really "trap" into the dispatcher (for example, cond() and
    shape()), we need a separate singleton tracer exposed to user space,
    in addition to the Dispatcher, to trigger graph capturing.
    """
    global TRACER
    TRACER, prev = tracer, TRACER
    try:
        yield
    finally:
        TRACER = prev

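# Typical usage of using_tracer (sketch; `tracer` is any DispatchTracer):
#
#   with using_tracer(tracer):
#       ...  # DispatchTracer.get() returns `tracer` within this scope
#   # on exit, the previous tracer (possibly None) is restored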
class DispatchTracer(fx.Tracer):
    def __init__(self) -> None:
        super().__init__()
        self.root: torch.nn.Module = torch.nn.Module()
        self.tensor_attrs: Dict[torch.Tensor, str] = {}
        self.submodules: Dict[fx.GraphModule, str] = {}

    def call_module(
        self,
        m: torch.nn.Module,
        forward: Callable[..., Value],
        args: Tuple[Value, ...],
        kwargs: Dict[str, Value],
    ) -> Value:
        return forward(*args, **kwargs)

    def _module_getattr(
        self, attr: str, attr_val: Value, parameter_proxy_cache: Dict[str, torch.Tensor]
    ) -> Value:
        if isinstance(attr_val, torch.nn.Parameter):
            for n, p in self.root.named_parameters():
                if attr_val is p:
                    if n not in parameter_proxy_cache:
                        proxy = self.create_proxy("get_attr", n, (), {})
                        parameter_proxy_cache[n] = PythonTensor(attr_val, proxy)
                    return parameter_proxy_cache[n]
            return attr_val
        return attr_val

    def create_arg(self, a: Value) -> torch.fx.Node:  # noqa: C901
        if isinstance(a, torch.nn.Parameter):
            for n, p in self.root.named_parameters():
                if a is p:
                    return self.create_node("get_attr", n, (), {})
            qualname: Optional[str] = None

            if not qualname:
                i = 0
                while True:
                    qualname = f"_param_constant{i}"
                    if not hasattr(self.root, qualname):
                        break
                    i += 1
                setattr(self.root, qualname, a)

            return self.create_node("get_attr", qualname, (), {})

        if isinstance(a, torch.Tensor):
            qualname: Optional[str] = self.tensor_attrs.get(a)

            if not qualname:
                i = 0
                while True:
                    qualname = f"_tensor_constant{i}"
                    if not hasattr(self.root, qualname):
                        break
                    i += 1
                self.tensor_attrs[a] = qualname
                self.root.register_buffer(qualname, a)

            return self.create_node("get_attr", qualname, (), {})

        # higher-order operator
        if isinstance(a, fx.GraphModule):
            if a not in self.submodules:
                name_submodule = f"submodule_{len(self.submodules)}"
                self.root.add_module(name_submodule, a)
                self.submodules[a] = name_submodule
            return self.create_node("get_attr", self.submodules[a], (), {})

        return super().create_arg(a)  # pyre-fixme[7]

    @staticmethod
    def get() -> "DispatchTracer":
        return TRACER

    def trace(  # pyre-fixme[14,15]
        self,
        root: Callable[..., Value],
        concrete_args: Tuple[Value, ...] = (),
        in_spec: Optional[TreeSpec] = None,
    ) -> Value:
        """
        Traces the given callable with the given concrete arguments.
        """
        with using_tracer(self):
            return self._trace(root, concrete_args=concrete_args, in_spec=in_spec)

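    # `trace` is the public entry point; `_trace` below does the actual work
    # once this tracer is installed as the global TRACER. A hypothetical call:
    #   DispatchTracer().trace(lambda x: x + 1, concrete_args=(torch.randn(2),))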
    def _trace(
        self,
        root: Callable[..., Value],
        concrete_args: Tuple[Value, ...],
        in_spec: Optional[TreeSpec],
    ) -> Value:
        self.root = torch.nn.Module()
        root_fn = root

        tracer_cls = getattr(self, "__class__", None)
        self.graph = fx.Graph(tracer_cls=tracer_cls)
        # We don't support tracing modules here, so tensor_attrs is always empty
        self.tensor_attrs = {}

        # Wrap all inputs as a PythonTensor subclass and insert them into the FX
        # graph as placeholder nodes
        def wrap(arg: Value, i: int) -> Value:
            placeholder = self.create_proxy("placeholder", f"ph_{i}", (), {})
            if isinstance(arg, torch.Tensor):
                return PythonTensor(arg, placeholder, is_immutable=True)
            else:
                # torch._assert(
                #     placeholder == arg,
                #     f"ph_{i} has been specialized to have value {arg}",
                # )
                return arg

        tree_args = [wrap(arg, i) for i, arg in enumerate(concrete_args)]
        if in_spec:
            tree_args = pytree.tree_unflatten(tree_args, in_spec)

        tree_out = root_fn(*tree_args)

        out_args, _ = pytree.tree_flatten(tree_out)

        def unwrap(out: LeafValue) -> Union[LeafValue, torch.fx.Proxy]:
            # It's legitimate for a model to return a list of items, some of
            # which are None
            if out is None:
                return None
            if not isinstance(out, torch.Tensor):
                raise TypeError(
                    f"Expect model to return torch.Tensor, got type: '{type(out)}' (value: {out})."
                )
            return unwrap_proxy(out)

        returns = [unwrap(out) for out in out_args]

        return_annotation = None
        # Some ops, like torch.sub, don't have annotations
        if hasattr(root_fn, "__annotations__"):
            return_annotation = root_fn.__annotations__.get("return", None)

        self.create_proxy(
            "output",
            "output",
            (returns,),
            {},
            type_expr=return_annotation,
        )

        self.submodule_paths = None

        return tree_out


TRACER: Optional[DispatchTracer] = None
TORCHDYNAMO_ENABLED: bool = False


@contextmanager
def using_dynamo(val: bool) -> Generator[None, None, None]:
    global TORCHDYNAMO_ENABLED
    TORCHDYNAMO_ENABLED, prev = val, TORCHDYNAMO_ENABLED
    try:
        yield
    finally:
        TORCHDYNAMO_ENABLED = prev


def flattened_dispatch_trace(
    f: Callable[..., Value],
    args: Tuple[LeafValue, ...],
    guards: Set[Guard],
    in_spec: Optional[TreeSpec] = None,
    enable_functionalization: bool = True,
) -> Tuple[torch.fx.GraphModule, Value]:
    if not isinstance(args, tuple):
        raise TypeError(f"Expecting 'args' to be a tuple, got: {type(args)}")

    tracer = DispatchTracer()

    if enable_functionalization:
        f = functionalize(f, remove="mutations_and_views")
    tree_out = tracer.trace(f, concrete_args=args, in_spec=in_spec)

    name = type(f).__name__ if isinstance(f, torch.nn.Module) else f.__name__
    gm = torch.fx.GraphModule(tracer.root, tracer.graph, name)

    return (gm, tree_out)


@dataclass
class ExirDynamoConfig:
    """
    Manage EXIR-specific configurations of Dynamo.
    """

    allow_rnn: bool = True
    verbose: bool = True
    assume_static_by_default: bool = False


def flatten_output(gm: torch.fx.GraphModule) -> None:
    """
    Modifies the output nodes in the submodules to return the result as a
    flattened list. This keeps it consistent with the result of EXIR's
    tracer.
    """
    for node in reversed(gm.graph.nodes):
        if node.op == "output":
            assert len(node.args) == 1
            outputs = node.args[0]
            returns, _ = pytree.tree_flatten(outputs)
            node.args = (returns,)
            return
    raise RuntimeError(f"Could not find an output node in {gm.graph}")

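# For example (illustrative only): if an output node currently returns the
# nested structure ([t0, t1], {"a": t2}), flatten_output rewrites it to
# return the flat list [t0, t1, t2], in pytree flattening order.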
def _default_decomposition_table(
    _use_old_decomp_table=False,
) -> Dict[torch._ops.OpOverload, Callable[..., Value]]:
    if _use_old_decomp_table:
        decomp_opset = [
            torch.ops.aten.log_sigmoid_forward,
            torch.ops.aten.ones,
            torch.ops.aten.arange.default,
            torch.ops.aten.arange.start,
            torch.ops.aten.transpose,
        ]
        # pyre-fixme[7]: Expected `Dict[OpOverload, typing.Callable[..., executorch.e...
        return get_decompositions(decomp_opset)
    # pyre-fixme[7]: Expected `Dict[OpOverload, typing.Callable[..., executorch.exir....
    return default_decompositions()


def dynamo_trace(
    f: Callable[..., Value],
    # pyre-ignore
    args: Tuple[Any, ...],
    aten_graph: bool,
    tracing_mode: str = "real",
    dynamo_config: Optional[ExirDynamoConfig] = None,
    # pyre-ignore
    dynamic_shapes: Optional[List[Any]] = None,
    _use_old_decomp_table: bool = False,
) -> Tuple[torch.fx.GraphModule, Set[Guard]]:
    """
    TODO: Once we fully migrate to the torchdynamo frontend, we will remove
    this config option altogether. For now, it helps with quick experiments
    when playing around with TorchDynamo.
    """
    if dynamo_config is None:
        dynamo_config = ExirDynamoConfig()

    with torchdynamo.config.patch(
        asdict(dynamo_config)
    ), setting_python_recursive_limit(2000):
        torchdynamo.reset()
        try:
            # TODO merge executorch functionalization with official
            # functionalization
            # pyre-fixme[7]: Expected `Tuple[GraphModule, Set[Guard]]` but got
            #  `ExportResult`.
            return torchdynamo.export(
                f,
                aten_graph=aten_graph,
                tracing_mode=tracing_mode,
                assume_static_by_default=dynamo_config.assume_static_by_default,
                decomposition_table=(
                    _default_decomposition_table(_use_old_decomp_table)
                    if aten_graph
                    else None
                ),
                dynamic_shapes=dynamic_shapes,
            )(
                *copy.deepcopy(args),
            )
        except torchdynamo.exc.Unsupported as exc:
            raise ExportError(
                ExportErrorType.NOT_SUPPORTED,
                "The user code is using a feature we don't support. "
                "Please try torchdynamo.explain() to get the possible reasons",
            ) from exc
        except Exception as exc:
            raise InternalError(
                "torchdynamo internal error occurred. Please see above stacktrace"
            ) from exc


def dispatch_trace(
    f: Callable[..., Value],
    args: Tuple[Value, ...],
) -> torch.fx.GraphModule:
    """
    Executes a given callable `f` with a given tuple of arguments. During
    execution, Tensor operations are recorded in an fx.GraphModule, which is
    then returned.

    Args:
        f: A `nn.Module` or a Python function that implements an ML program.
        args: A tuple of arguments of any type to be used as inputs for the
            tracing run.

    Returns:
        EXIR contained in an fx.GraphModule.
    """
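    # Hypothetical usage sketch (MyModule is a stand-in for any nn.Module):
    #   gm = dispatch_trace(MyModule(), (torch.randn(2, 3),))
    #   gm.in_spec / gm.out_spec then describe the pytree structure of the
    #   inputs and outputs.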
    trace_func = f
    guards = set()
    if TORCHDYNAMO_ENABLED:
        # Copying args is safer in case downstream implementations of trace_func mutate them
        trace_func, guards = dynamo_trace(trace_func, args, False)

    # Copying args is safer in case downstream implementations of trace_func mutate them
    trace_args, in_spec = pytree.tree_flatten(args)

    in_args = copy.deepcopy(tuple(trace_args))
    gm, tree_out = flattened_dispatch_trace(
        trace_func,
        in_args,
        guards,
        in_spec,
        enable_functionalization=False,
    )

    _, out_spec = pytree.tree_flatten(tree_out)

    # pyre-fixme[16]: `GraphModule` has no attribute `in_spec`.
    gm.in_spec = in_spec
    # pyre-fixme[16]: `GraphModule` has no attribute `out_spec`.
    gm.out_spec = out_spec

    # TODO (tmanlaibaatar) This is a bit clowny, but our dispatch_trace
    # sometimes creates unused nodes that break functionalization. It seems
    # too much trouble to fix it properly since dispatch_trace will be
    # deprecated soon. Basically dispatch_trace struggles on:
    #   def f(x: torch.Tensor) -> torch.Tensor:
    #       return torch.ones(6, dtype=x.dtype)
    changed = gm.graph.eliminate_dead_code()
    if changed:
        gm.recompile()

    in_args = copy.deepcopy(tuple(trace_args))
    assert callable(gm)

    # This wrapper is used for preserving the stacktrace
    # during the second round of tracing.
    # pyre-ignore
    def graph_with_interpreter(*args):
        try:
            args = fx_pytree.tree_flatten_spec(args, gm.in_spec)  # type: ignore[assignment]
        except Exception:
            _, received_spec = pytree.tree_flatten(args)
            raise RuntimeError(
                "Trying to flatten user inputs with exported input tree spec: \n"
                f"{gm.in_spec}\n"
                "but actually got inputs with tree spec of: \n"
                f"{received_spec}"
            )
        with torch.fx.traceback.preserve_node_meta():
            res = gm(*args)

        if gm.out_spec is not None:
            try:
                res = pytree.tree_unflatten(res, gm.out_spec)
            except Exception:
                _, received_spec = pytree.tree_flatten(res)
                raise RuntimeError(
                    "Trying to flatten user outputs with exported output tree spec: \n"
                    f"{gm.out_spec}\n"
                    "but actually got outputs with tree spec of: \n"
                    f"{received_spec}"
                )
        return res

    gm, tree_out = flattened_dispatch_trace(
        graph_with_interpreter,
        in_args,
        guards,
        in_spec,
        enable_functionalization=True,
    )

    gm.in_spec = in_spec
    gm.out_spec = out_spec

    return gm
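# End-to-end sketch of the flow implemented in this file (illustrative only;
# `f` is a hypothetical user function):
#
#   def f(x):
#       return torch.relu(x) + 1
#
#   gm = dispatch_trace(f, (torch.randn(4),))
#   # Each aten call in f hit PythonTensor.__torch_dispatch__, which recorded
#   # a corresponding node; gm.graph now holds the captured FX graph.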