xref: /aosp_15_r20/external/executorch/exir/pass_base.py (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1# Copyright (c) Meta Platforms, Inc. and affiliates.
2# All rights reserved.
3#
4# This source code is licensed under the BSD-style license found in the
5# LICENSE file in the root directory of this source tree.
6
7# pyre-strict
8
9import operator
10import traceback
11from contextlib import nullcontext
12from typing import (
13    Any,
14    Callable,
15    Dict,
16    List,
17    MutableMapping,
18    Optional,
19    Protocol,
20    runtime_checkable,
21    Set,
22    Tuple,
23    TypeVar,
24    Union,
25)
26
27import torch
28from executorch.exir import memory
29
30from executorch.exir.delegate import executorch_call_delegate, is_lowered_module
31
32from executorch.exir.dialects.edge._ops import EdgeOpOverload
33from executorch.exir.error import ExportError, ExportErrorType
34from torch import fx
35from torch._dispatch.python import enable_python_dispatcher
36from torch._subclasses import FakeTensorMode, UnsupportedFakeTensorException
37from torch._subclasses.fake_tensor import FakeTensor
38from torch._subclasses.functional_tensor import FunctionalTensor, FunctionalTensorMode
39from torch.fx import traceback as fx_traceback
40from torch.fx.experimental.proxy_tensor import PythonKeyTracer
41from torch.fx.graph import CodeGen
42from torch.fx.passes.infra.pass_base import PassBase, PassResult
43from torch.fx.passes.shape_prop import _extract_tensor_metadata, TensorMetadata
44from torch.utils import _pytree as pytree
45from torch.utils._pytree import PyTree
46
47Fn = Callable[..., Any]  # pyre-ignore
48Argument = Any  # pyre-ignore
49Value = Any  # pyre-ignore
50NodeMetadataValue = Any  # pyre-ignore
51K = TypeVar("K")
52PassType = Callable[[torch.fx.GraphModule], Optional[PassResult]]
53
54
55_TORCH_SYM_OPS: Set[Any] = {  # pyre-ignore
56    torch.sym_int,
57    torch.sym_float,
58    torch.sym_ite,
59    torch.sym_max,
60    torch.sym_min,
61    torch.sym_not,
62    torch.sym_sqrt,
63}
64
65
66PROTECTED_KEYS: Set[str] = {
67    "val",
68    "stack_trace",
69    "nn_module_stack",
70    "debug_handle",
71    "tensor_meta",
72}
73
74
75def _unstack_pytree(xs) -> List[PyTree]:  # pyre-ignore
76    flat_xs, inspec = pytree.tree_flatten(xs)
77    if not all(isinstance(xs, torch.Tensor) for xs in flat_xs):
78        raise RuntimeError(f"Leaves of xs must be Tensor {flat_xs}")
79
80    if not all(xs.shape[0] == flat_xs[0].shape[0] for xs in flat_xs):
81        raise RuntimeError(
82            f"Leaves of xs must have same leading dimension size {[xs.shape for xs in flat_xs]}"
83        )
84
85    ctx = (
86        FunctionalTensorMode
87        if any(isinstance(x, FunctionalTensor) for x in flat_xs)
88        else nullcontext
89    )
90    with ctx():
91        a = zip(*flat_xs)
92
93    pytrees = []
94    for tuple in a:
95        pytrees.append(pytree.tree_unflatten(tuple, inspec))
96    return pytrees
97
98
class NodeMetadata:
    """Private copy of an fx node's ``meta`` dict that guards reserved keys.

    Keys listed in ``PROTECTED_KEYS`` can be read, but assigning them through
    ``__setitem__`` raises.
    """

    def __init__(self, data: Dict[str, Any]) -> None:
        # Shallow copy, so writes through this object never touch the source dict.
        self.data: Dict[str, Any] = data.copy()

    def __getitem__(self, key: str) -> NodeMetadataValue:
        return self.data[key]

    def __setitem__(self, key: str, value: NodeMetadataValue) -> None:
        """Store ``value`` under ``key``.

        Raises:
            RuntimeError: if ``key`` is one of ``PROTECTED_KEYS``.
        """
        if key in PROTECTED_KEYS:
            raise RuntimeError(f"Could not override node key: {key}")
        self.data[key] = value

    def __contains__(self, key: str) -> bool:
        return key in self.data

    def copy(self) -> "NodeMetadata":
        """Return an independent (shallow) copy of this metadata."""
        # ``__init__`` already copies ``data``; an extra copy here would be redundant.
        return NodeMetadata(self.data)
116
117
class ProxyValue:
    """Pair a traced value with the fx graph element that produces it.

    ``data`` is the concrete (or fake) value and ``proxy_or_node`` is either a
    ``torch.fx.Proxy`` or a bare ``torch.fx.Node`` standing for it in a graph.
    """

    # pyre-ignore
    def __init__(self, data, proxy: Union[torch.fx.Proxy, torch.fx.Node]):
        # pyre-ignore
        self.data = data
        self.proxy_or_node = proxy

    @property
    def node(self) -> torch.fx.Node:
        """The underlying ``torch.fx.Node``, unwrapping a ``Proxy`` if one is held."""
        held = self.proxy_or_node
        if not isinstance(held, torch.fx.Node):
            assert isinstance(held, torch.fx.Proxy)
            held = held.node
        return held

    @property
    def proxy(self) -> torch.fx.Proxy:
        """The attached ``Proxy``; raises ``RuntimeError`` when only a ``Node`` is held."""
        held = self.proxy_or_node
        if isinstance(held, torch.fx.Proxy):
            return held
        raise RuntimeError(
            f"ProxyValue doesn't have attached Proxy object. Node: {held.format_node()}"
        )

    def to_tensor(self) -> torch.Tensor:
        """Return ``data``, asserting that it is a tensor."""
        value = self.data
        assert isinstance(value, torch.Tensor)
        return value

    def is_tensor(self) -> bool:
        """Whether ``data`` is a ``torch.Tensor``."""
        return isinstance(self.data, torch.Tensor)

    # pyre-ignore
    def __iter__(self):
        # Iterate the wrapped data lazily, element by element.
        for item in self.data:
            yield item

    def __bool__(self) -> bool:
        return bool(self.data)
153
154
class ExportPassBaseError(RuntimeError):
    """Error raised by the export-pass machinery, e.g. for unsupported graph constructs."""
157
158
class _ExportPassBase(PassBase):
    """
    Interpreter-based pass class to help users maintain the IR spec while writing
    transformations.

    ``call`` re-runs the input graph node by node through ``ExportInterpreter``.
    Each node is forwarded to an overridable hook on this class (``placeholder``,
    ``call_operator``, ``call_sym``, ``call_getitem``, ``call_cond``, ``call_map``,
    ``output``). The default hooks re-emit the node into a fresh graph through
    ``ExportTracer`` and attach ``meta["val"]`` / ``meta["tensor_meta"]`` computed
    from the re-executed values.
    """

    @staticmethod
    def _create_dummy_node_metadata() -> NodeMetadata:
        """Return metadata holding only a one-frame ``stack_trace`` entry."""
        return NodeMetadata({"stack_trace": "".join(traceback.format_stack(limit=1))})

    class ExportTracer(PythonKeyTracer):
        """Tracer that collects the nodes re-emitted by the pass into a new graph.

        Nodes are added one at a time (via ``create_proxy``) by the pass hooks.
        Tracing a whole callable through ``trace()`` is rejected.
        """

        def __init__(self, callback: "_ExportPassBase", codegen: CodeGen) -> None:
            super().__init__()
            self.callback = callback
            # Module that owns every submodule / tensor attribute referenced by the new graph.
            self.root = torch.nn.Module()
            self.graph = torch.fx.Graph()
            self.graph.set_codegen(codegen)
            self.tensor_attrs: Dict[str, torch.Tensor] = {}  # type: ignore[assignment]
            self.fake_tensor_mode: Optional[FakeTensorMode] = None
            # Submodule -> attribute name it was registered under on ``self.root``.
            self.submodules: Dict[torch.nn.Module, str] = {}

        def trace(self) -> None:  # pyre-fixme[14,15]
            raise ExportPassBaseError("ExportTracer doesn't support trace().")

        def create_arg(self, a: Argument) -> torch.fx.Node:
            """Convert ``a`` into a graph argument, registering modules and constants.

            Modules are attached to ``self.root`` as ``submodule_<n>``. A FakeTensor
            is only accepted if it wraps a real ``constant``, which is embedded
            instead. Tensors that become ``get_attr`` nodes get metadata attached
            and are reported to ``callback.on_attr``.
            """
            if isinstance(a, torch.nn.Module):
                if a not in self.submodules:
                    name_submodule = f"submodule_{len(self.submodules)}"
                    self.root.add_module(name_submodule, a)
                    self.submodules[a] = name_submodule
            elif isinstance(a, FakeTensor):
                if not hasattr(a, "constant") or a.constant is None:
                    raise ExportPassBaseError(f"Cannot add {a} to graph.")
                a = a.constant
            node = super().create_arg(a)
            if (
                isinstance(a, torch.Tensor)
                and isinstance(node, torch.fx.Node)
                and node.op == "get_attr"
            ):
                self.set_metadata(node, a)
                self.callback.on_attr(ProxyValue(a, node))
            return node

        def set_metadata(  # noqa: C901
            self,
            node: torch.fx.Node,
            value: Argument,
        ) -> None:
            """Populate ``node.meta["val"]`` and ``node.meta["tensor_meta"]`` from ``value``.

            Both entries are pytrees mirroring ``value``.
            - ``val`` holds fake tensors (real tensors are fakeified with
              ``self.fake_tensor_mode``), symbolic/python scalars and strings,
              and ``None`` for anything else.
            - ``tensor_meta`` is non-``None`` only for tensors that cannot be
              fakeified.
            """

            # propagate the fake tensor or sym nodes
            def make_val(
                x: Argument,
            ) -> Union[
                FakeTensor,
                torch.SymInt,
                torch.SymFloat,
                torch.SymBool,
                int,
                float,
                bool,
                str,
                None,
            ]:
                if isinstance(x, FakeTensor):
                    return x
                elif isinstance(x, torch.Tensor):
                    if x.is_quantized:
                        # TODO (tmanlaibaatar) properly support Quantized FakeTensor
                        x = torch.dequantize(x)

                    try:
                        assert self.fake_tensor_mode is not None
                        # TODO we should allocate static shapes
                        # for param/buffer values
                        if isinstance(x, torch.nn.Parameter):
                            fake_tensor = self.fake_tensor_mode.from_tensor(
                                x, static_shapes=True
                            )
                        else:
                            fake_tensor = self.fake_tensor_mode.from_tensor(x)
                    except UnsupportedFakeTensorException:
                        # TODO: This is just a workaround to get over the
                        # x.as_subclass error
                        print(
                            "Fakeifying a Tensor subclass is not supported "
                            "right now. Instead a TensorMetadata is used."
                        )
                        fake_tensor = None
                    return fake_tensor
                elif isinstance(
                    x,
                    (
                        torch.SymInt,
                        torch.SymFloat,
                        torch.SymBool,
                        int,
                        float,
                        bool,
                        str,
                    ),
                ):
                    return x
                else:
                    return None

            node.meta["val"] = pytree.tree_map(make_val, value)

            # Set the tensor_metadata for values that do not have a corresponding FakeTensor
            def make_tensor_meta(x: Argument) -> Optional[TensorMetadata]:
                if not isinstance(x, FakeTensor) and isinstance(x, torch.Tensor):
                    if x.is_quantized:
                        # TODO (tmanlaibaatar) properly support Quantized FakeTensor
                        x = torch.dequantize(x)

                    try:
                        assert self.fake_tensor_mode is not None
                        # Only probing whether fakeification succeeds; the result is unused.
                        _ = self.fake_tensor_mode.from_tensor(x)
                        tensor_meta = None
                    except UnsupportedFakeTensorException:
                        # TODO: This is just a workaround to get over the
                        # x.as_subclass error
                        tensor_meta = _extract_tensor_metadata(x)
                    return tensor_meta
                else:
                    return None

            node.meta["tensor_meta"] = pytree.tree_map(make_tensor_meta, value)

    class ExportInterpreter(fx.Interpreter):
        """fx.Interpreter that hands every node of ``gm`` to the hooks of ``callback``."""

        def __init__(self, callback: "_ExportPassBase", gm: fx.GraphModule) -> None:
            super().__init__(gm)
            self.callback = callback
            # Node currently being interpreted; updated in ``run_node`` and used to
            # build the NodeMetadata passed to each hook.
            self.node: torch.fx.Node = next(iter(gm.graph.nodes))

        def placeholder(  # pyre-fixme[14]
            self,
            target: str,
            args: Tuple[Argument, ...],
            kwargs: Dict[str, Argument],
        ) -> ProxyValue:
            arg = super().placeholder(target, args, kwargs)
            return self.callback.placeholder(target, arg, NodeMetadata(self.node.meta))

        def output(
            self,
            target: torch.fx.node.Target,
            args: Tuple[Argument, ...],
            kwargs: Dict[str, Argument],
        ) -> Argument:
            # The interpreter's return value is the unwrapped output data.
            return self.callback.output(args[0], NodeMetadata(self.node.meta)).data

        def call_function(
            self,
            target: torch.fx.node.Target,
            args: Tuple[Argument, ...],
            kwargs: Dict[str, Argument],
        ) -> ProxyValue:
            """Dispatch a ``call_function`` node to the matching callback hook.

            getitem -> ``call_getitem``; python/math/torch sym functions ->
            ``call_sym``; OpOverload(Packet) -> ``call_operator``; ``cond`` and
            ``map_impl`` -> ``call_cond`` / ``call_map``; any other
            HigherOrderOperator is re-emitted as-is. Other targets raise
            ``ExportPassBaseError``.
            """
            meta = NodeMetadata(self.node.meta)

            if target == operator.getitem:
                value, key = args
                return self.callback.call_getitem(value, key, meta)
            elif getattr(target, "__module__", None) in {
                "_operator",
                "builtins",
                "math",
            }:
                assert callable(target)
                return self.callback.call_sym(target, args, meta)
            elif target in _TORCH_SYM_OPS:
                assert callable(target)
                return self.callback.call_sym(target, args, meta)
            elif isinstance(
                target, (torch._ops.OpOverload, torch._ops.OpOverloadPacket)
            ):
                return self.callback.call_operator(
                    target,
                    args,
                    kwargs,
                    meta,
                )
            elif target == torch.ops.higher_order.cond:
                pred, true_fn, false_fn, inputs = args
                return self.callback.call_cond(pred, true_fn, false_fn, inputs, meta)
            elif target == torch.ops.higher_order.map_impl:
                f, mapped_args, operands = args  # type: ignore[assignment]
                return self.callback.call_map(f, mapped_args, operands, meta)
            # For other unregistered HigherOrderOps, just interpret them blindly
            elif isinstance(target, torch._ops.HigherOrderOperator):
                return self.callback._fx(
                    "call_function",
                    target,
                    args,
                    kwargs,
                    meta,
                )
            else:
                raise ExportPassBaseError(f"Unsupported target type: {target}")

        def get_attr(  # pyre-fixme[14]
            self, target: str, args: Tuple[Argument, ...], kwargs: Dict[str, Argument]
        ) -> Argument:
            return super().get_attr(target, args, kwargs)

        def call_module(
            self,
            target: torch.fx.node.Target,
            args: Tuple[Argument, ...],
            kwargs: Dict[str, Argument],
        ) -> None:
            raise ExportPassBaseError("call_module is not supported.")

        def call_method(  # pyre-fixme[14]
            self, target: str, args: Tuple[Argument, ...], kwargs: Dict[str, Argument]
        ) -> None:
            raise ExportPassBaseError("call_method is not supported.")

        def run_node(self, n: torch.fx.Node) -> Argument:
            self.node = n
            self.callback.node_debug_str = n.format_node()
            return super().run_node(n)

    def __init__(self) -> None:
        # Plain interpreter used by ``_fx`` to execute ops on unwrapped data.
        self.interpreter = torch.fx.Interpreter(
            torch.fx.GraphModule(torch.nn.Module(), torch.fx.Graph())
        )
        self.tracer = self.ExportTracer(self, CodeGen())  # pyre-ignore
        self.fake_tensor_mode: Optional[FakeTensorMode] = None
        self._initialized = True
        self.node_debug_str: Optional[str] = None

    def _fx(
        self,
        kind: str,
        target: torch.fx.node.Target,
        args: Tuple[Argument, ...],
        kwargs: Dict[str, Argument],
        meta: NodeMetadata,
    ) -> ProxyValue:
        """Execute ``target`` on the unwrapped data and record the call in the new graph.

        Returns a ProxyValue pairing the computed result with the newly created
        proxy. ``meta`` is copied onto the new node before ``val`` /
        ``tensor_meta`` are recomputed from the result.
        """
        args_data, kwargs_data = pytree.tree_map_only(
            ProxyValue, lambda x: x.data, (args, kwargs)
        )
        res_data = getattr(self.interpreter, kind)(target, args_data, kwargs_data)
        args_proxy, kwargs_proxy = pytree.tree_map_only(
            ProxyValue, lambda x: x.proxy, (args, kwargs)
        )

        name = None
        if isinstance(target, torch._ops.OpOverload):
            name = self.tracer.graph._target_to_str(target.overloadpacket.__name__)

        res_proxy = self.tracer.create_proxy(
            kind, target, args_proxy, kwargs_proxy, name=name
        )
        res_proxy.node.meta.update(meta.data)
        self.tracer.set_metadata(res_proxy.node, res_data)
        return ProxyValue(res_data, res_proxy)

    def inputs(self, graph_module: torch.fx.GraphModule) -> List[Argument]:
        """Return example inputs for ``graph_module``.

        Uses ``graph_module.meta["args"]`` when present. Otherwise each placeholder
        contributes one of the following, in order of preference:
        - its ``meta["val"]``, or the wrapped constant if that is a constant FakeTensor;
        - a FakeTensor rebuilt from ``meta["tensor_meta"]``;
        - ``None`` if the placeholder is unused.

        Raises:
            ExportPassBaseError: if a used placeholder has neither entry.
        """
        # TODO(angelayi): Update this with what we decide to do for metadata in
        # the exported graph module
        if (args := graph_module.meta.get("args", None)) is not None:
            return list(args)

        def extract_input(node: torch.fx.Node) -> Argument:
            if "val" in node.meta:
                fake = node.meta["val"]
                if hasattr(fake, "constant") and fake.constant is not None:
                    return fake.constant
                return fake
            elif tensor_meta := node.meta.get("tensor_meta"):
                assert self.fake_tensor_mode is not None
                return FakeTensor(
                    self.fake_tensor_mode,
                    torch.empty(
                        tensor_meta.shape,
                        dtype=tensor_meta.dtype,
                        device="meta",
                        requires_grad=tensor_meta.requires_grad,
                        memory_format=tensor_meta.memory_format,
                    ),
                    torch.device("cpu"),
                )
            elif len(node.users) == 0:
                return None
            raise ExportPassBaseError(
                f"Cannot construct an input for graph module: {graph_module}.",
            )

        return [
            extract_input(node)
            for node in graph_module.graph.nodes
            if node.op == "placeholder"
        ]

    def on_attr(self, attr: ProxyValue) -> None:
        """Hook called for every tensor attribute added to the new graph; no-op by default."""
        pass

    def placeholder(self, name: str, arg: Argument, meta: NodeMetadata) -> ProxyValue:
        """Create a placeholder in the new graph carrying ``meta`` and ``val=arg``."""
        arg_proxy = self.tracer.create_proxy("placeholder", name, (), {})
        arg_proxy.node.meta = meta.data
        arg_proxy.node.meta["val"] = arg
        return ProxyValue(arg, arg_proxy)

    def call_operator(
        self,
        op,  # pyre-ignore
        args: Tuple[Argument, ...],
        kwargs: Dict[str, Argument],
        meta: NodeMetadata,
    ) -> ProxyValue:
        return self._fx("call_function", op, args, kwargs, meta)

    def call_sym(
        self,
        target: Fn,
        args: Tuple[Argument, ...],
        meta: NodeMetadata,
    ) -> ProxyValue:
        return self._fx("call_function", target, args, {}, meta)

    def call_cond(
        self,
        pred: ProxyValue,
        true_fn: torch.fx.GraphModule,
        false_fn: torch.fx.GraphModule,
        inputs: List[Argument],
        meta: NodeMetadata,
    ) -> ProxyValue:
        """Re-trace both branches of ``cond`` with this pass, then emit the new ``cond``."""
        true_branch = self.call_submodule(true_fn, tuple(inputs))
        false_branch = self.call_submodule(false_fn, tuple(inputs))
        assert true_branch is not None
        assert false_branch is not None
        return self._fx(
            "call_function",
            torch.ops.higher_order.cond,
            (pred, true_branch.graph_module, false_branch.graph_module, list(inputs)),
            {},
            meta,
        )

    def call_map(
        self,
        f: torch.fx.GraphModule,
        mapped_args: List[ProxyValue],
        operands: List[ProxyValue],
        meta: NodeMetadata,
    ) -> ProxyValue:
        """Re-trace the ``map`` body on the first slice of ``mapped_args``, then emit ``map_impl``."""
        xs = _unstack_pytree([arg.data for arg in mapped_args])[0]
        f_branch = self.call_submodule(f, tuple(xs + [arg.data for arg in operands]))
        assert f_branch is not None
        return self._fx(
            "call_function",
            torch.ops.higher_order.map_impl,
            (f_branch.graph_module, mapped_args, operands),
            {},
            meta,
        )

    def call_getitem(
        self, value: ProxyValue, key: int, meta: NodeMetadata
    ) -> ProxyValue:
        return self._fx("call_function", operator.getitem, (value, key), {}, meta)

    def output(self, results: List[Argument], meta: NodeMetadata) -> ProxyValue:
        return self._fx("output", "output", (results,), {}, meta)

    def call_submodule(
        self, graph_module: fx.GraphModule, inputs: Tuple[Argument, ...]
    ) -> PassResult:
        """Re-trace ``graph_module`` on ``inputs`` with a fresh tracer/interpreter.

        The previous tracer and interpreter are restored even if interpretation
        raises, so a failure inside a nested submodule (e.g. a ``cond`` branch)
        does not leave the pass pointing at a half-built graph.
        """
        prev_tracer = self.tracer
        prev_interpreter = self.interpreter
        try:
            self.tracer = self.ExportTracer(self, graph_module.graph._codegen)
            self.tracer.fake_tensor_mode = prev_tracer.fake_tensor_mode
            interpreter = self.ExportInterpreter(self, graph_module)
            self.interpreter = torch.fx.Interpreter(
                torch.fx.GraphModule(torch.nn.Module(), torch.fx.Graph())
            )
            inputs_data = pytree.tree_map_only(ProxyValue, lambda x: x.data, inputs)
            with fx_traceback.preserve_node_meta():
                interpreter.run(*inputs_data)

            new_graph_module = torch.fx.GraphModule(self.tracer.root, self.tracer.graph)
        finally:
            self.tracer = prev_tracer
            self.interpreter = prev_interpreter

        return PassResult(
            new_graph_module,
            True,
        )

    def call(self, graph_module: fx.GraphModule) -> PassResult:
        """Entry point: rebuild ``graph_module`` by running every node through the hooks.

        Raises:
            ExportPassBaseError: if ``__init__`` was not called.
        """
        if not getattr(self, "_initialized", False):
            raise ExportPassBaseError(
                "ExportPass is not initialized with __init__().",
            )

        inputs = self.inputs(graph_module)

        # All fake inputs must share a single FakeTensorMode; reuse it if present.
        fake_tensor_mode = None
        for i in inputs:
            if isinstance(i, FakeTensor):
                assert (
                    fake_tensor_mode is None or fake_tensor_mode is i.fake_mode
                ), "Multiple fake tensor mode detected."
                fake_tensor_mode = i.fake_mode
        if fake_tensor_mode is None:
            # No fake inputs: execute without entering a mode, but give the tracer
            # its own FakeTensorMode so metadata values can still be fakeified.
            self.tracer.fake_tensor_mode = FakeTensorMode(allow_non_fake_inputs=True)
            fake_tensor_mode = nullcontext()  # type: ignore[assignment]
            dispatcher_mode = nullcontext()  # type: ignore[assignment]
        else:
            # NOTE(review): this flag stays set on the inputs' mode after the pass returns.
            fake_tensor_mode.allow_non_fake_inputs = True
            self.tracer.fake_tensor_mode = fake_tensor_mode
            dispatcher_mode = enable_python_dispatcher()  # type: ignore[assignment]
        self.fake_tensor_mode = self.tracer.fake_tensor_mode

        with fake_tensor_mode, dispatcher_mode:  # type: ignore[assignment, union-attr]
            result = self.call_submodule(graph_module, tuple(inputs))

        return result
579
580
class ExportPass(_ExportPassBase):
    """``_ExportPassBase`` extended with ExecuTorch-specific handling.

    - Routes edge-dialect ops (``EdgeOpOverload``) to ``call_operator``.
    - Re-emits ``memory.alloc`` unchanged.
    - Routes ``executorch_call_delegate`` to ``call_delegate``.
    - Names lowered backend modules ``lowered_module_<n>`` in the rebuilt graph.
    - Restores each placeholder's original ``meta["val"]`` after re-tracing.
    """

    class ExportTracer(_ExportPassBase.ExportTracer):
        def create_arg(self, a: Argument) -> torch.fx.Node:
            """Register modules under a descriptive name, then defer to the base tracer.

            Lowered backend modules get the ``lowered_module_`` prefix and other
            modules get ``submodule_``. The numeric suffix is the count of modules
            registered so far.
            """
            if isinstance(a, torch.nn.Module):
                if a not in self.submodules:
                    prefix = "lowered_module" if is_lowered_module(a) else "submodule"
                    name_submodule = f"{prefix}_{len(self.submodules)}"
                    self.root.add_module(name_submodule, a)
                    self.submodules[a] = name_submodule
            return super().create_arg(a)

    class ExportInterpreter(_ExportPassBase.ExportInterpreter):
        """
        Interpreter to callback on any ExportPassBase functions
        """

        def __init__(self, callback: "ExportPass", gm: fx.GraphModule) -> None:
            super().__init__(callback, gm)

        def call_function(
            self,
            target: torch.fx.node.Target,
            args: Tuple[Argument, ...],
            kwargs: Dict[str, Argument],
        ) -> ProxyValue:
            """Handle ExecuTorch-specific targets, deferring everything else to the base class."""
            meta = NodeMetadata(self.node.meta)
            if target == operator.getitem:
                value, key = args
                return self.callback.call_getitem(value, key, meta)
            elif isinstance(target, EdgeOpOverload):
                return self.callback.call_operator(
                    target,
                    args,
                    kwargs,
                    meta,
                )

            # TODO according to zhengxu ExportPassBase should not be aware of
            # memory.alloc. Check this comment:
            # https://www.internalfb.com/diff/D42758019?dst_version_fbid=5906016402813292&transaction_fbid=1104713900200176
            elif target == memory.alloc:
                return self.callback._fx(
                    "call_function",
                    target,
                    args,
                    kwargs,
                    meta,
                )

            elif target == executorch_call_delegate:
                # First positional argument is the lowered module; the rest are its inputs.
                lowered_module = args[0]
                args = args[1:]
                return self.callback.call_delegate(  # pyre-ignore
                    lowered_module,
                    args,
                    kwargs,
                    meta,
                )

            return super().call_function(target, args, kwargs)

    def call_delegate(
        self,
        # pyre-ignore: Undefined or invalid type [11]: Annotation `LoweredBackendModule` is not defined as a type.
        lowered_module: "LoweredBackendModule",  # noqa
        args: Tuple[ProxyValue, ...],
        kwargs: Dict[str, Argument],
        meta: NodeMetadata,
    ) -> ProxyValue:
        """Re-emit ``executorch_call_delegate(lowered_module, *args, **kwargs)``."""
        args = (lowered_module,) + args
        return self._fx(
            "call_function",
            executorch_call_delegate,
            args,
            kwargs,
            meta,
        )

    def call_submodule(
        self, graph_module: fx.GraphModule, inputs: Tuple[Argument, ...]
    ) -> PassResult:
        """Re-trace ``graph_module`` and copy each original placeholder ``val`` back.

        Raises:
            ExportError: if the number of placeholders changed, or a placeholder's
                tensor ``val`` no longer matches the original in type or size.
        """
        res = super().call_submodule(graph_module, inputs)

        def preserve_original_ph_meta_val(
            gm: torch.fx.GraphModule, new_gm: torch.fx.GraphModule
        ) -> None:
            def get_phs(gm: torch.fx.GraphModule) -> List[torch.fx.Node]:
                return [node for node in gm.graph.nodes if node.op == "placeholder"]

            def migrate_meta_val(
                orig_phs: List[torch.fx.Node], new_phs: List[torch.fx.Node]
            ) -> None:
                if len(orig_phs) != len(new_phs):
                    raise ExportError(
                        ExportErrorType.NOT_SUPPORTED,
                        "ExportPassBase doesn't support changing the placeholders",
                    )
                for ph, new_ph in zip(orig_phs, new_phs):
                    if isinstance(new_ph.meta["val"], torch.Tensor):
                        if (
                            not isinstance(ph.meta["val"], torch.Tensor)
                            or new_ph.meta["val"].size() != ph.meta["val"].size()
                        ):
                            raise ExportError(
                                ExportErrorType.NOT_SUPPORTED,
                                "ExportPassBase doesn't support changing the placeholders",
                            )
                    new_ph.meta["val"] = ph.meta["val"]

            migrate_meta_val(get_phs(gm), get_phs(new_gm))

        # After one pass, new_graph_module's placeholders hold the values the pass
        # was run with (fake tensors under fake-tensor mode) in meta['val'], but
        # sometimes we want to preserve the original meta['val'] of placeholders.
        #
        # For example, custom flows and certain passes assume no fake_tensor_mode is
        # activated and don't work with fake_tensor_mode, but we don't bother to fix
        # them. So we reset the meta of placeholders to its original value. This is
        # safe because:
        # 1. For models captured with pt2_mode, the meta['val'] of placeholders are
        #    fake_tensors already, so preserving it in the new graph module won't hurt.
        # 2. For models captured with dispatch_trace, the meta['val'] field
        #    (NOTE(review): the original explanation was left unfinished; presumably
        #    it holds the non-fake values those flows expect — TODO confirm).
        # Note that it's only safe when passes don't modify the inputs.
        preserve_original_ph_meta_val(graph_module, res.graph_module)

        return res
705
706
@runtime_checkable
class ArgSchema(Protocol):
    """Structural type describing one argument of an operator schema.

    Any object exposing ``name``, ``kwarg_only`` and ``type`` attributes matches.
    Because the protocol is runtime-checkable, ``isinstance`` can test for it.
    """

    # Argument name as it appears in the schema.
    name: str
    # Whether the argument can only be passed by keyword.
    kwarg_only: bool
    type: Any  # pyre-ignore
712
713
def map_args(
    op: torch._ops.OpOverload,
    fn: Fn,
    args: Argument,
    kwargs: Dict[str, Argument],
) -> Tuple[Argument, Dict[str, Argument]]:
    """Apply ``fn(value, schema)`` to every supplied argument of an ``op`` call.

    Schema arguments are visited in ``op._schema.arguments`` order.
    - An argument present in ``kwargs`` is replaced there.
    - Otherwise, a positional argument within ``len(args)`` is replaced in ``args``.
    - Arguments that were not supplied are skipped.
    The inputs are not mutated.

    Returns:
        ``(new_args, new_kwargs)``, where ``new_args`` is a tuple and ``new_kwargs``
        is a copy of ``kwargs``.
    """
    assert isinstance(args, tuple)
    assert isinstance(kwargs, dict)
    positional = list(args)
    keyword = kwargs.copy()

    for position, arg_schema in enumerate(op._schema.arguments):
        arg_name = arg_schema.name
        if arg_name in keyword:
            keyword[arg_name] = fn(keyword[arg_name], arg_schema)
            continue
        # Keyword-only or omitted trailing positionals have nothing to map.
        if arg_schema.kwarg_only or position >= len(positional):
            continue
        positional[position] = fn(positional[position], arg_schema)

    return tuple(positional), keyword
735