# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

import operator
import traceback
from contextlib import nullcontext
from typing import (
    Any,
    Callable,
    Dict,
    List,
    MutableMapping,
    Optional,
    Protocol,
    runtime_checkable,
    Set,
    Tuple,
    TypeVar,
    Union,
)

import torch
from executorch.exir import memory

from executorch.exir.delegate import executorch_call_delegate, is_lowered_module

from executorch.exir.dialects.edge._ops import EdgeOpOverload
from executorch.exir.error import ExportError, ExportErrorType
from torch import fx
from torch._dispatch.python import enable_python_dispatcher
from torch._subclasses import FakeTensorMode, UnsupportedFakeTensorException
from torch._subclasses.fake_tensor import FakeTensor
from torch._subclasses.functional_tensor import FunctionalTensor, FunctionalTensorMode
from torch.fx import traceback as fx_traceback
from torch.fx.experimental.proxy_tensor import PythonKeyTracer
from torch.fx.graph import CodeGen
from torch.fx.passes.infra.pass_base import PassBase, PassResult
from torch.fx.passes.shape_prop import _extract_tensor_metadata, TensorMetadata
from torch.utils import _pytree as pytree
from torch.utils._pytree import PyTree

Fn = Callable[..., Any]  # pyre-ignore
Argument = Any  # pyre-ignore
Value = Any  # pyre-ignore
NodeMetadataValue = Any  # pyre-ignore
K = TypeVar("K")
# A pass is any callable taking a GraphModule and optionally returning a PassResult.
PassType = Callable[[torch.fx.GraphModule], Optional[PassResult]]


# Symbolic-math helpers from the ``torch`` namespace. The interpreter's
# ``call_function`` routes these to ``call_sym`` rather than ``call_operator``.
_TORCH_SYM_OPS: Set[Any] = {  # pyre-ignore
    torch.sym_int,
    torch.sym_float,
    torch.sym_ite,
    torch.sym_max,
    torch.sym_min,
    torch.sym_not,
    torch.sym_sqrt,
}


# Node.meta keys that ``NodeMetadata.__setitem__`` refuses to overwrite.
PROTECTED_KEYS: Set[str] = {
    "val",
    "stack_trace",
    "nn_module_stack",
    "debug_handle",
    "tensor_meta",
}


def _unstack_pytree(xs) -> List[PyTree]:  # pyre-ignore
    """
    Split a pytree of tensors along its leading dimension.

    Every leaf of ``xs`` must be a ``torch.Tensor``, and all leaves must have
    the same size in dim 0. Element ``i`` of the returned list has the same
    pytree structure as ``xs``, with each leaf replaced by the ``i``-th item
    obtained by iterating that leaf.

    Raises:
        RuntimeError: if a leaf is not a tensor or leading sizes disagree.
    """
    flat_xs, inspec = pytree.tree_flatten(xs)
    if not all(isinstance(leaf, torch.Tensor) for leaf in flat_xs):
        raise RuntimeError(f"Leaves of xs must be Tensor {flat_xs}")

    if not all(leaf.shape[0] == flat_xs[0].shape[0] for leaf in flat_xs):
        raise RuntimeError(
            f"Leaves of xs must have same leading dimension size {[leaf.shape for leaf in flat_xs]}"
        )

    # Enter FunctionalTensorMode when any leaf is a FunctionalTensor;
    # otherwise use a no-op context.
    ctx = (
        FunctionalTensorMode
        if any(isinstance(x, FunctionalTensor) for x in flat_xs)
        else nullcontext
    )
    with ctx():
        # zip() calls iter() on every leaf immediately, so the per-tensor
        # iterator setup runs while ``ctx`` is active.
        per_index_leaves = zip(*flat_xs)

    return [pytree.tree_unflatten(leaves, inspec) for leaves in per_index_leaves]


class NodeMetadata:
    """
    Wrapper around a shallow copy of an fx ``Node.meta`` dict.

    Keys in ``PROTECTED_KEYS`` can be read but not assigned through
    ``__setitem__``. The raw dict remains reachable via ``.data``.
    """

    def __init__(self, data: Dict[str, Any]) -> None:
        # Shallow copy so that writes through this wrapper do not alias the
        # dict it was constructed from.
        self.data: Dict[str, Any] = data.copy()

    def __getitem__(self, key: str) -> NodeMetadataValue:
        return self.data[key]

    def __setitem__(self, key: str, value: NodeMetadataValue) -> None:
        """Store ``value`` under ``key``; raises RuntimeError for protected keys."""
        if key in PROTECTED_KEYS:
            raise RuntimeError(f"Could not override node key: {key}")
        self.data[key] = value

    def __contains__(self, key: str) -> bool:
        return key in self.data

    def copy(self) -> "NodeMetadata":
        """Return a new NodeMetadata over a shallow copy of ``data``."""
        # The constructor already performs the shallow copy.
        return NodeMetadata(self.data)


class ProxyValue:
    """
    Pairs a concrete (typically fake) value with the fx Proxy or Node that
    produces it in the graph under construction.
    """

    # pyre-ignore
    def __init__(self, data, proxy: Union[torch.fx.Proxy, torch.fx.Node]):
        # pyre-ignore
        self.data = data
        self.proxy_or_node = proxy

    @property
    def node(self) -> torch.fx.Node:
        """The underlying fx Node, unwrapping a Proxy if necessary."""
        if isinstance(self.proxy_or_node, torch.fx.Node):
            return self.proxy_or_node
        assert isinstance(self.proxy_or_node, torch.fx.Proxy)
        return self.proxy_or_node.node

    @property
    def proxy(self) -> torch.fx.Proxy:
        """The attached Proxy; raises RuntimeError if only a Node is held."""
        if not isinstance(self.proxy_or_node, torch.fx.Proxy):
            raise RuntimeError(
                f"ProxyValue doesn't have attached Proxy object. Node: {self.proxy_or_node.format_node()}"
            )
        return self.proxy_or_node

    def to_tensor(self) -> torch.Tensor:
        assert isinstance(self.data, torch.Tensor)
        return self.data

    def is_tensor(self) -> bool:
        return isinstance(self.data, torch.Tensor)

    # pyre-ignore
    def __iter__(self):
        yield from self.data

    def __bool__(self) -> bool:
        return bool(self.data)


class ExportPassBaseError(RuntimeError):
    pass


class _ExportPassBase(PassBase):
    """
    Interpreter-based pass class to help users maintain the IR spec while writing
    transformations.
    """

    @staticmethod
    def _create_dummy_node_metadata() -> NodeMetadata:
        """Build a NodeMetadata whose only entry is a one-frame stack trace string."""
        return NodeMetadata({"stack_trace": "".join(traceback.format_stack(limit=1))})

    class ExportTracer(PythonKeyTracer):
        """
        Tracer that records the re-emitted graph.

        Its ``root`` module owns any submodules and constant tensors that
        ``create_arg`` lifts into the graph.
        """

        def __init__(self, callback: "_ExportPassBase", codegen: CodeGen) -> None:
            super().__init__()
            self.callback = callback
            self.root = torch.nn.Module()
            self.graph = torch.fx.Graph()
            self.graph.set_codegen(codegen)
            self.tensor_attrs: Dict[str, torch.Tensor] = {}  # type: ignore[assignment]
            self.fake_tensor_mode: Optional[FakeTensorMode] = None
            # Maps each registered submodule to its attribute name on ``root``.
            self.submodules: Dict[torch.nn.Module, str] = {}

        def trace(self) -> None:  # pyre-fixme[14,15]
            raise ExportPassBaseError("ExportTracer doesn't support trace().")

        def create_arg(self, a: Argument) -> torch.fx.Node:
            """
            Convert ``a`` into a graph argument.

            - ``nn.Module`` values are registered on ``root`` as
              ``submodule_<n>``.
            - FakeTensors are replaced by their ``.constant``; ExportPassBaseError
              is raised if they have none.
            - Tensors that become ``get_attr`` nodes get metadata and trigger
              ``callback.on_attr``.
            """
            if isinstance(a, torch.nn.Module):
                if a not in self.submodules:
                    name_submodule = f"submodule_{len(self.submodules)}"
                    self.root.add_module(name_submodule, a)
                    self.submodules[a] = name_submodule
            elif isinstance(a, FakeTensor):
                if not hasattr(a, "constant") or a.constant is None:
                    raise ExportPassBaseError(f"Cannot add {a} to graph.")
                a = a.constant
            node = super().create_arg(a)
            if (
                isinstance(a, torch.Tensor)
                and isinstance(node, torch.fx.Node)
                and node.op == "get_attr"
            ):
                self.set_metadata(node, a)
                self.callback.on_attr(ProxyValue(a, node))
            return node

        def set_metadata(  # noqa: C901
            self,
            node: torch.fx.Node,
            value: Argument,
        ) -> None:
            """
            Populate ``node.meta["val"]`` and ``node.meta["tensor_meta"]`` from
            the pytree ``value``.

            For ``val``:
            - FakeTensors are kept as-is.
            - Real tensors (dequantized first if quantized) are converted with
              ``self.fake_tensor_mode``; the result is None when conversion
              raises UnsupportedFakeTensorException.
            - SymInt/SymFloat/SymBool/int/float/bool/str are kept.
            - Anything else becomes None.

            For ``tensor_meta``: a leaf is a TensorMetadata only for real tensors
            that cannot be fakeified; otherwise it is None.
            """

            # propagate the fake tensor or sym nodes
            def make_val(
                x: Argument,
            ) -> Union[
                FakeTensor,
                torch.SymInt,
                torch.SymFloat,
                torch.SymBool,
                int,
                float,
                bool,
                str,
                None,
            ]:
                if isinstance(x, FakeTensor):
                    return x
                elif isinstance(x, torch.Tensor):
                    if x.is_quantized:
                        # TODO (tmanlaibaatar) properly support Quantized FakeTensor
                        x = torch.dequantize(x)

                    try:
                        assert self.fake_tensor_mode is not None
                        # TODO we should allocate static shapes
                        # for param/buffer values
                        if isinstance(x, torch.nn.Parameter):
                            fake_tensor = self.fake_tensor_mode.from_tensor(
                                x, static_shapes=True
                            )
                        else:
                            fake_tensor = self.fake_tensor_mode.from_tensor(x)
                    except UnsupportedFakeTensorException:
                        # TODO: This is just a workaround to get over the
                        # x.as_subclass error
                        # Implicit concatenation (not a backslash continuation
                        # inside the literal) so no indentation leaks into the
                        # printed message.
                        print(
                            "Fakeifying a Tensor subclass is not supported "
                            "right now. Instead a TensorMetadata is used."
                        )
                        fake_tensor = None
                    return fake_tensor
                elif isinstance(
                    x,
                    (
                        torch.SymInt,
                        torch.SymFloat,
                        torch.SymBool,
                        int,
                        float,
                        bool,
                        str,
                    ),
                ):
                    return x
                else:
                    return None

            node.meta["val"] = pytree.tree_map(make_val, value)

            # Set the tensor_metadata for values that do not have a corresponding FakeTensor
            def make_tensor_meta(x: Argument) -> Optional[TensorMetadata]:
                if not isinstance(x, FakeTensor) and isinstance(x, torch.Tensor):
                    if x.is_quantized:
                        # TODO (tmanlaibaatar) properly support Quantized FakeTensor
                        x = torch.dequantize(x)

                    try:
                        assert self.fake_tensor_mode is not None
                        _ = self.fake_tensor_mode.from_tensor(x)
                        tensor_meta = None
                    except UnsupportedFakeTensorException:
                        # TODO: This is just a workaround to get over the
                        # x.as_subclass error
                        tensor_meta = _extract_tensor_metadata(x)
                    return tensor_meta
                else:
                    return None

            node.meta["tensor_meta"] = pytree.tree_map(make_tensor_meta, value)

    class ExportInterpreter(fx.Interpreter):
        """
        fx.Interpreter that forwards each node to the owning pass's callbacks
        (``placeholder``, ``output``, ``call_operator``, ``call_sym``, ...),
        passing a NodeMetadata copy of the node currently being run.
        """

        def __init__(self, callback: "_ExportPassBase", gm: fx.GraphModule) -> None:
            super().__init__(gm)
            self.callback = callback
            # Updated by run_node; initialised to the first node of the graph.
            self.node: torch.fx.Node = next(iter(gm.graph.nodes))

        def placeholder(  # pyre-fixme[14]
            self,
            target: str,
            args: Tuple[Argument, ...],
            kwargs: Dict[str, Argument],
        ) -> ProxyValue:
            arg = super().placeholder(target, args, kwargs)
            return self.callback.placeholder(target, arg, NodeMetadata(self.node.meta))

        def output(
            self,
            target: torch.fx.node.Target,
            args: Tuple[Argument, ...],
            kwargs: Dict[str, Argument],
        ) -> ProxyValue:
            return self.callback.output(args[0], NodeMetadata(self.node.meta)).data

        def call_function(
            self,
            target: torch.fx.node.Target,
            args: Tuple[Argument, ...],
            kwargs: Dict[str, Argument],
        ) -> ProxyValue:
            """
            Dispatch a ``call_function`` node to the matching pass callback.

            Raises ExportPassBaseError for targets of unsupported kinds.
            """
            meta = NodeMetadata(self.node.meta)

            if target == operator.getitem:
                value, key = args
                return self.callback.call_getitem(value, key, meta)
            elif getattr(target, "__module__", None) in {
                "_operator",
                "builtins",
                "math",
            }:
                assert callable(target)
                return self.callback.call_sym(target, args, meta)
            elif target in _TORCH_SYM_OPS:
                assert callable(target)
                return self.callback.call_sym(target, args, meta)
            elif isinstance(
                target, (torch._ops.OpOverload, torch._ops.OpOverloadPacket)
            ):
                return self.callback.call_operator(
                    target,
                    args,
                    kwargs,
                    meta,
                )
            elif target == torch.ops.higher_order.cond:
                pred, true_fn, false_fn, inputs = args
                return self.callback.call_cond(pred, true_fn, false_fn, inputs, meta)
            elif target == torch.ops.higher_order.map_impl:
                f, mapped_args, operands = args  # type: ignore[assignment]
                return self.callback.call_map(f, mapped_args, operands, meta)
            # For other unregistered HigherOrderOps, just interpret them blindly
            elif isinstance(target, torch._ops.HigherOrderOperator):
                return self.callback._fx(
                    "call_function",
                    target,
                    args,
                    kwargs,
                    meta,
                )
            else:
                raise ExportPassBaseError(f"Unsupported target type: {target}")

        def get_attr(  # pyre-fixme[14]
            self, target: str, args: Tuple[Argument, ...], kwargs: Dict[str, Argument]
        ) -> Argument:
            return super().get_attr(target, args, kwargs)

        def call_module(
            self,
            target: torch.fx.node.Target,
            args: Tuple[Argument, ...],
            kwargs: Dict[str, Argument],
        ) -> None:
            raise ExportPassBaseError("call_module is not supported.")

        def call_method(  # pyre-fixme[14]
            self, target: str, args: Tuple[Argument, ...], kwargs: Dict[str, Argument]
        ) -> None:
            raise ExportPassBaseError("call_method is not supported.")

        def run_node(self, n: torch.fx.Node) -> Argument:
            # Remember the current node so the handlers above can read its meta,
            # and expose its formatted form on the pass for debugging.
            self.node = n
            self.callback.node_debug_str = n.format_node()
            return super().run_node(n)

    def __init__(self) -> None:
        # Scratch interpreter used by _fx to evaluate targets on concrete values.
        self.interpreter = torch.fx.Interpreter(
            torch.fx.GraphModule(torch.nn.Module(), torch.fx.Graph())
        )
        self.tracer = self.ExportTracer(self, CodeGen())  # pyre-ignore
        self.fake_tensor_mode: Optional[FakeTensorMode] = None
        # Checked by call() to detect subclasses that skipped __init__.
        self._initialized = True
        self.node_debug_str: Optional[str] = None

    def _fx(
        self,
        kind: str,
        target: torch.fx.node.Target,
        args: Tuple[Argument, ...],
        kwargs: Dict[str, Argument],
        meta: NodeMetadata,
    ) -> ProxyValue:
        """
        Evaluate ``target`` and record it in the graph being built.

        The ``.data`` of every ProxyValue in args/kwargs is run through
        ``self.interpreter.<kind>`` to compute the result value. The ``.proxy``
        of every ProxyValue is used to emit the same call via
        ``self.tracer.create_proxy``. ``meta`` is copied onto the new node,
        followed by val/tensor_meta derived from the result.
        """
        args_data, kwargs_data = pytree.tree_map_only(
            ProxyValue, lambda x: x.data, (args, kwargs)
        )
        res_data = getattr(self.interpreter, kind)(target, args_data, kwargs_data)
        args_proxy, kwargs_proxy = pytree.tree_map_only(
            ProxyValue, lambda x: x.proxy, (args, kwargs)
        )

        # Name OpOverload nodes after their overload packet (e.g. "add").
        name = None
        if isinstance(target, torch._ops.OpOverload):
            name = self.tracer.graph._target_to_str(target.overloadpacket.__name__)

        res_proxy = self.tracer.create_proxy(
            kind, target, args_proxy, kwargs_proxy, name=name
        )
        res_proxy.node.meta.update(meta.data)
        self.tracer.set_metadata(res_proxy.node, res_data)
        return ProxyValue(res_data, res_proxy)

    def inputs(self, graph_module: torch.fx.GraphModule) -> List[Argument]:
        """
        Build example inputs for re-tracing ``graph_module``.

        If ``graph_module.meta["args"]`` is set, it is returned as a list.
        Otherwise, for each placeholder in graph order, the input is:
        - its ``meta["val"]`` (or that value's ``.constant`` when set), else
        - a FakeTensor rebuilt from ``meta["tensor_meta"]``, else
        - None when the placeholder has no users.

        Raises:
            ExportPassBaseError: if an input cannot be reconstructed.
        """
        # TODO(angelayi): Update this with what we decide to do for metadata in
        # the exported graph module
        if (args := graph_module.meta.get("args", None)) is not None:
            return list(args)

        def extract_input(node: torch.fx.Node) -> Optional[FakeTensor]:
            if "val" in node.meta:
                fake = node.meta["val"]
                if hasattr(fake, "constant") and fake.constant is not None:
                    return fake.constant
                return fake
            elif tensor_meta := node.meta.get("tensor_meta"):
                assert self.fake_tensor_mode is not None
                return FakeTensor(
                    self.fake_tensor_mode,
                    torch.empty(
                        tensor_meta.shape,
                        dtype=tensor_meta.dtype,
                        device="meta",
                        requires_grad=tensor_meta.requires_grad,
                        memory_format=tensor_meta.memory_format,
                    ),
                    torch.device("cpu"),
                )
            elif len(node.users) == 0:
                return None
            raise ExportPassBaseError(
                f"Cannot construct an input for graph module: {graph_module}.",
            )

        return [
            extract_input(node)
            for node in graph_module.graph.nodes
            if node.op == "placeholder"
        ]

    def on_attr(self, attr: ProxyValue) -> None:
        """Hook called for each tensor lifted into a ``get_attr`` node; no-op by default."""
        pass

    def placeholder(self, name: str, arg: Argument, meta: NodeMetadata) -> ProxyValue:
        """Emit a placeholder node named ``name`` whose ``meta["val"]`` is ``arg``."""
        arg_proxy = self.tracer.create_proxy("placeholder", name, (), {})
        arg_proxy.node.meta = meta.data
        arg_proxy.node.meta["val"] = arg
        return ProxyValue(arg, arg_proxy)

    def call_operator(
        self,
        op,  # pyre-ignore
        args: Tuple[Argument, ...],
        kwargs: Dict[str, Argument],
        meta: NodeMetadata,
    ) -> ProxyValue:
        return self._fx("call_function", op, args, kwargs, meta)

    def call_sym(
        self,
        target: Fn,
        args: Tuple[Argument, ...],
        meta: NodeMetadata,
    ) -> ProxyValue:
        return self._fx("call_function", target, args, {}, meta)

    def call_cond(
        self,
        pred: ProxyValue,
        true_fn: torch.fx.GraphModule,
        false_fn: torch.fx.GraphModule,
        inputs: List[Argument],
        meta: NodeMetadata,
    ) -> ProxyValue:
        """Re-trace both branches with ``inputs`` and emit a ``cond`` node over them."""
        true_branch = self.call_submodule(true_fn, tuple(inputs))
        false_branch = self.call_submodule(false_fn, tuple(inputs))
        assert true_branch is not None
        assert false_branch is not None
        return self._fx(
            "call_function",
            torch.ops.higher_order.cond,
            (pred, true_branch.graph_module, false_branch.graph_module, list(inputs)),
            {},
            meta,
        )

    def call_map(
        self,
        f: torch.fx.GraphModule,
        mapped_args: List[ProxyValue],
        operands: List[ProxyValue],
        meta: NodeMetadata,
    ) -> ProxyValue:
        """
        Re-trace ``f`` and emit a ``map_impl`` node over it.

        ``f`` is traced on the first slice (along dim 0) of ``mapped_args``
        followed by the ``operands``.
        NOTE(review): the ``[0]`` raises IndexError when ``mapped_args`` is
        empty or has a leading size of 0.
        """
        xs = _unstack_pytree([arg.data for arg in mapped_args])[0]
        f_branch = self.call_submodule(f, tuple(xs + [arg.data for arg in operands]))
        assert f_branch is not None
        return self._fx(
            "call_function",
            torch.ops.higher_order.map_impl,
            (f_branch.graph_module, mapped_args, operands),
            {},
            meta,
        )

    def call_getitem(
        self, value: ProxyValue, key: int, meta: NodeMetadata
    ) -> ProxyValue:
        return self._fx("call_function", operator.getitem, (value, key), {}, meta)

    def output(self, results: List[Argument], meta: NodeMetadata) -> ProxyValue:
        return self._fx("output", "output", (results,), {}, meta)

    def call_submodule(
        self, graph_module: fx.GraphModule, inputs: Tuple[Argument, ...]
    ) -> PassResult:
        """
        Re-trace ``graph_module`` with ``inputs`` into a fresh graph.

        For the duration of the run, a new ExportTracer (sharing the current
        tracer's fake tensor mode) and a fresh scratch Interpreter replace
        ``self.tracer`` and ``self.interpreter``. The previous ones are restored
        afterwards, including when interpretation raises, so a failed
        sub-trace does not leave this pass bound to a half-built graph.

        Returns:
            PassResult wrapping the newly traced GraphModule; its ``modified``
            flag is always True.
        """
        prev_tracer = self.tracer
        prev_interpreter = self.interpreter
        try:
            self.tracer = self.ExportTracer(self, graph_module.graph._codegen)
            self.tracer.fake_tensor_mode = prev_tracer.fake_tensor_mode
            interpreter = self.ExportInterpreter(self, graph_module)
            self.interpreter = torch.fx.Interpreter(
                torch.fx.GraphModule(torch.nn.Module(), torch.fx.Graph())
            )
            inputs_data = pytree.tree_map_only(ProxyValue, lambda x: x.data, inputs)
            with fx_traceback.preserve_node_meta():
                interpreter.run(*inputs_data)

            new_graph_module = torch.fx.GraphModule(
                self.tracer.root, self.tracer.graph
            )
        finally:
            self.tracer = prev_tracer
            self.interpreter = prev_interpreter
        return PassResult(
            new_graph_module,
            True,
        )

    def call(self, graph_module: fx.GraphModule) -> PassResult:
        """
        Run the pass over ``graph_module``.

        If any input is a FakeTensor, that FakeTensorMode (which must be shared
        by all FakeTensor inputs) is reused, set to allow non-fake inputs, and
        entered together with the python dispatcher. Otherwise a fresh
        FakeTensorMode is attached to the tracer, and no mode is entered here.

        Raises:
            ExportPassBaseError: if ``__init__`` was never called.
        """
        if not getattr(self, "_initialized", False):
            raise ExportPassBaseError(
                "ExportPass is not initialized with __init__().",
            )

        inputs = self.inputs(graph_module)

        fake_tensor_mode = None
        for i in inputs:
            if isinstance(i, FakeTensor):
                assert (
                    fake_tensor_mode is None or fake_tensor_mode is i.fake_mode
                ), "Multiple fake tensor mode detected."
                fake_tensor_mode = i.fake_mode
        if fake_tensor_mode is None:
            self.tracer.fake_tensor_mode = FakeTensorMode(allow_non_fake_inputs=True)
            fake_tensor_mode = nullcontext()  # type: ignore[assignment]
            dispatcher_mode = nullcontext()  # type: ignore[assignment]
        else:
            fake_tensor_mode.allow_non_fake_inputs = True
            self.tracer.fake_tensor_mode = fake_tensor_mode
            dispatcher_mode = enable_python_dispatcher()  # type: ignore[assignment]
        self.fake_tensor_mode = self.tracer.fake_tensor_mode

        with fake_tensor_mode, dispatcher_mode:  # type: ignore[assignment, union-attr]
            result = self.call_submodule(graph_module, tuple(inputs))

        return result


class ExportPass(_ExportPassBase):
    """
    _ExportPassBase extended with ExecuTorch specifics: EdgeOpOverload,
    ``memory.alloc`` and ``executorch_call_delegate`` handling, plus naming for
    lowered modules.
    """

    class ExportTracer(_ExportPassBase.ExportTracer):
        def create_arg(self, a: Argument) -> torch.fx.Node:
            # Register modules here first so lowered modules get a
            # "lowered_module_<n>" name; the base class then skips them because
            # they are already present in ``self.submodules``.
            if isinstance(a, torch.nn.Module):
                if a not in self.submodules:
                    prefix = "lowered_module" if is_lowered_module(a) else "submodule"
                    name_submodule = f"{prefix}_{len(self.submodules)}"
                    self.root.add_module(name_submodule, a)
                    self.submodules[a] = name_submodule
            return super().create_arg(a)

    class ExportInterpreter(_ExportPassBase.ExportInterpreter):
        """
        Interpreter to callback on any ExportPassBase functions
        """

        def __init__(self, callback: "ExportPass", gm: fx.GraphModule) -> None:
            super().__init__(callback, gm)

        def call_function(
            self,
            target: torch.fx.node.Target,
            args: Tuple[Argument, ...],
            kwargs: Dict[str, Argument],
        ) -> ProxyValue:
            """
            Handle ExecuTorch-specific targets, then defer to the base dispatch.

            The specific targets are getitem, EdgeOpOverload, ``memory.alloc``
            and ``executorch_call_delegate``.
            """
            meta = NodeMetadata(self.node.meta)
            if target == operator.getitem:
                value, key = args
                return self.callback.call_getitem(value, key, meta)
            elif isinstance(target, EdgeOpOverload):
                return self.callback.call_operator(
                    target,
                    args,
                    kwargs,
                    meta,
                )

            # TODO according to zhengxu ExportPassBase should not be aware of
            # memory.alloc. Check this comment:
            # https://www.internalfb.com/diff/D42758019?dst_version_fbid=5906016402813292&transaction_fbid=1104713900200176
            elif target == memory.alloc:
                return self.callback._fx(
                    "call_function",
                    target,
                    args,
                    kwargs,
                    meta,
                )

            elif target == executorch_call_delegate:
                # args[0] is the lowered module; the rest are its inputs.
                lowered_module = args[0]
                args = args[1:]
                return self.callback.call_delegate(  # pyre-ignore
                    lowered_module,
                    args,
                    kwargs,
                    meta,
                )

            return super().call_function(target, args, kwargs)

    def call_delegate(
        self,
        # pyre-ignore: Undefined or invalid type [11]: Annotation `LoweredBackendModule` is not defined as a type.
        lowered_module: "LoweredBackendModule",  # noqa
        args: Tuple[ProxyValue, ...],
        kwargs: Dict[str, Argument],
        meta: NodeMetadata,
    ) -> ProxyValue:
        """Emit an ``executorch_call_delegate(lowered_module, *args)`` node."""
        args = (lowered_module,) + args
        return self._fx(
            "call_function",
            executorch_call_delegate,
            args,
            kwargs,
            meta,
        )

    def call_submodule(
        self, graph_module: fx.GraphModule, inputs: Tuple[Argument, ...]
    ) -> PassResult:
        """
        Re-trace ``graph_module`` via the base class, then copy each original
        placeholder's ``meta["val"]`` onto the corresponding new placeholder.

        Raises:
            ExportError(NOT_SUPPORTED): if the placeholder count changed, or a
                tensor placeholder's size no longer matches the original.
        """
        res = super().call_submodule(graph_module, inputs)

        def preserve_original_ph_meta_val(
            gm: torch.fx.GraphModule, new_gm: torch.fx.GraphModule
        ) -> None:
            def get_phs(gm: torch.fx.GraphModule) -> List[torch.fx.Node]:
                return [node for node in gm.graph.nodes if node.op == "placeholder"]

            def migrate_meta_val(
                orig_phs: List[torch.fx.Node], new_phs: List[torch.fx.Node]
            ) -> None:
                if len(orig_phs) != len(new_phs):
                    raise ExportError(
                        ExportErrorType.NOT_SUPPORTED,
                        "ExportPassBase doesn't support changing the placeholders",
                    )
                for ph, new_ph in zip(orig_phs, new_phs):
                    if isinstance(new_ph.meta["val"], torch.Tensor):
                        if (
                            not isinstance(ph.meta["val"], torch.Tensor)
                            or new_ph.meta["val"].size() != ph.meta["val"].size()
                        ):
                            raise ExportError(
                                ExportErrorType.NOT_SUPPORTED,
                                "ExportPassBase doesn't support changing the placeholders",
                            )
                    new_ph.meta["val"] = ph.meta["val"]

            migrate_meta_val(get_phs(gm), get_phs(new_gm))

        # After one pass, new_graph_module's placeholders will always hold fake tensors in
        # meta['val'] but sometimes we want to preserve the original meta['val'] of placeholders
        #
        # For example, custom flows and certain passes assume no fake_tensor_mode is activated
        # and it doesn't quite work with fake_tensor_mode. but we don't bother to fix them.
        # So we'll just reset the meta of placeholders to its original value. It's safe because that
        # 1. For models captured with pt2_mode, the meta['val'] of placeholders are fake_tensors already, so
        # preserving it to the new graph module won't hurt.
        # 2. For models captured with dispatch_trace, the meta['val'] field
        #    (the original explanation ends here — presumably the original
        #    value is what those flows expect; TODO confirm the intended rationale).
        # Note that it's only safe when passes don't modify the inputs.
        preserve_original_ph_meta_val(graph_module, res.graph_module)

        return res


@runtime_checkable
class ArgSchema(Protocol):
    """Structural type for the schema argument entries that ``map_args`` reads."""

    name: str
    kwarg_only: bool
    type: Any  # pyre-ignore


def map_args(
    op: torch._ops.OpOverload,
    fn: Fn,
    args: Argument,
    kwargs: Dict[str, Argument],
) -> Tuple[Argument, Dict[str, Argument]]:
    """
    Apply ``fn(value, schema)`` to every argument of ``op`` that is present.

    Each schema argument is matched by name in ``kwargs`` first, or else by
    position in ``args`` if it is not kwarg-only and the index is in range.
    The inputs are not mutated; new ``(args_tuple, kwargs_dict)`` are returned.
    """
    assert isinstance(args, tuple)
    assert isinstance(kwargs, dict)
    args = list(args)
    kwargs = kwargs.copy()

    def update(key: K, args: MutableMapping[K, PyTree], schema: ArgSchema) -> None:
        args[key] = fn(args[key], schema)

    for i, schema in enumerate(op._schema.arguments):
        if schema.name in kwargs:
            update(schema.name, kwargs, schema)
        elif not schema.kwarg_only and i < len(args):
            update(i, args, schema)  # pyre-ignore

    return tuple(args), kwargs