xref: /aosp_15_r20/external/executorch/exir/program/_program.py (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1# Copyright (c) Meta Platforms, Inc. and affiliates.
2# All rights reserved.
3#
4# This source code is licensed under the BSD-style license found in the
5# LICENSE file in the root directory of this source tree.
6
7# pyre-unsafe
8
9import copy
10import io
11import logging
12from typing import Any, Dict, List, Optional, Sequence, Set, TextIO, Tuple, Union
13
14import torch
15import torch._export
16from executorch.exir._serialize import _serialize_pte_binary
17from executorch.exir._serialize._cord import Cord
18from executorch.exir._warnings import experimental
19from executorch.exir.backend.backend_api import to_backend
20from executorch.exir.backend.partitioner import Partitioner
21from executorch.exir.capture._config import EdgeCompileConfig, ExecutorchBackendConfig
22from executorch.exir.emit import emit_program, EmitterOutput
23from executorch.exir.emit._emitter import _DelegateDebugIdentifierMap
24from executorch.exir.error import ExportError
25from executorch.exir.graph_module import get_control_flow_submodules
26from executorch.exir.pass_base import PassBase
27from executorch.exir.pass_manager import PassType
28from executorch.exir.passes import (
29    base_post_op_replace_passes,
30    base_pre_op_replace_passes,
31    dead_code_elimination_pass,
32    EdgeToBackendOpsPass,
33    MemoryFormatOpsPass,
34    OpReplacePass,
35)
36from executorch.exir.passes.insert_write_back_for_buffers_pass import (
37    insert_write_back_for_buffers_pass,
38)
39from executorch.exir.passes.normalize_view_copy_base_pass import (
40    NormalizeViewCopyBasePass,
41)
42from executorch.exir.passes.remove_graph_asserts_pass import RemoveGraphAssertsPass
43from executorch.exir.passes.remove_mixed_type_operators import RemoveMixedTypeOperators
44from executorch.exir.passes.replace_aten_with_edge_pass import aten_to_edge
45from executorch.exir.passes.replace_view_copy_with_view_pass import (
46    ReplaceViewCopyWithViewPass,
47)
48from executorch.exir.passes.spec_prop_pass import SpecPropPass
49from executorch.exir.passes.weights_to_outputs_pass import weights_to_outputs_pass
50from executorch.exir.print_program import pretty_print, print_program
51from executorch.exir.schema import Program
52from executorch.exir.tracer import _default_decomposition_table
53from executorch.exir.verification.verifier import (
54    EXIRATenDialectVerifier,
55    EXIREdgeDialectVerifier,
56    get_aten_verifier,
57)
58from torch._export.passes import ReplaceViewOpsWithViewCopyOpsPass
59from torch.export import ExportedProgram
60from torch.export._remove_auto_functionalized_pass import (
61    unsafe_remove_auto_functionalized_pass,
62)
63from torch.export.exported_program import (
64    ConstantArgument,
65    ExportGraphSignature,
66    InputKind,
67    InputSpec,
68    OutputSpec,
69    TensorArgument,
70)
71from torch.fx import _pytree as fx_pytree
72from torch.fx._compatibility import compatibility
73from torch.fx.passes.infra.pass_manager import PassManager
74from torch.utils import _pytree as pytree
75
76Val = Any
77
78from torch.library import Library
79
80# This is the reserved namespace that is used to register ops to that will
81# be prevented from being decomposed during to_edge_transform_and_lower.
82edge_no_decomp_namespace = "EDGE_DO_NOT_DECOMP"
83lib = Library(edge_no_decomp_namespace, "DEF")
84# Map from aten ops to the transformed ops registered in the edge_no_decomp_namespace.
85aten_op_to_transform_op = {}
86# Map from the transformed ops registered in the edge_no_decomp_namespace to aten ops.
87transform_op_to_aten_op = {}
88
89
90def _get_updated_range_constraints(gm):
91    def get_shape_env(gm):
92        vals = [
93            node.meta["val"]
94            for node in gm.graph.nodes
95            if node.meta.get("val", None) is not None
96        ]
97        from torch._guards import detect_fake_mode  # type: ignore[21]
98
99        fake_mode = detect_fake_mode(vals)
100        if fake_mode is not None:
101            return fake_mode.shape_env
102        for v in vals:
103            if isinstance(v, torch.SymInt):
104                return v.node.shape_env
105
106    shape_env = get_shape_env(gm)
107    if shape_env is None:
108        return {}
109    range_constraints = {
110        k: v
111        for k, v in shape_env.var_to_range.items()
112        if k not in shape_env.replacements
113    }
114    # Only when we have an unbacked symint, and it's used as constructor inputs,
115    # runtime_var_to_range will make a difference compated to var_to_range.
116    # e.g. [2, oo) -> [0, oo)
117    for k, v in shape_env.var_to_range.items():
118        if k not in shape_env.replacements:
119            range_constraints[k] = v
120    return range_constraints
121
122
123def _get_updated_graph_signature(
124    old_signature: ExportGraphSignature,
125    new_gm: torch.fx.GraphModule,
126) -> ExportGraphSignature:
127    """
128    Update the graph signature's user_input/user_outputs.
129    """
130    new_input_specs = []
131    i = 0
132    for node in new_gm.graph.nodes:
133        if node.op != "placeholder":
134            continue
135
136        assert i < len(
137            old_signature.input_specs
138        ), "Number of inputs changed after transformation"
139        old_input_spec = old_signature.input_specs[i]
140        arg = (
141            old_input_spec.arg
142            if isinstance(old_input_spec.arg, ConstantArgument)
143            # pyre-fixme[20]: Argument `class_fqn` expected.
144            else type(old_input_spec.arg)(node.name)
145        )
146        new_input_specs.append(
147            InputSpec(
148                old_input_spec.kind,
149                arg,
150                old_input_spec.target,
151                persistent=old_input_spec.persistent,
152            )
153        )
154        i += 1
155
156    output_node = list(new_gm.graph.nodes)[-1]
157    assert output_node.op == "output"
158
159    new_output_specs = []
160    for i, node in enumerate(output_node.args[0]):
161        assert i < len(
162            old_signature.output_specs
163        ), "Number of outputs changed after transformation"
164        old_output_spec = old_signature.output_specs[i]
165        arg = (
166            old_output_spec.arg
167            if isinstance(old_output_spec.arg, ConstantArgument)
168            # pyre-fixme[20]: Argument `class_fqn` expected.
169            else type(old_output_spec.arg)(node.name)
170        )
171        new_output_specs.append(
172            OutputSpec(old_output_spec.kind, arg, old_output_spec.target)
173        )
174
175    new_signature = ExportGraphSignature(
176        input_specs=new_input_specs, output_specs=new_output_specs
177    )
178    return new_signature
179
180
181def _transform(self, *passes: PassType) -> "ExportedProgram":
182    pm = PassManager(list(passes))
183    res = pm(self.graph_module)
184    transformed_gm = res.graph_module if res is not None else self.graph_module
185    assert transformed_gm is not None
186
187    if transformed_gm is self.graph_module and not res.modified:
188        return self
189
190    transformed_ep = ExportedProgram(
191        root=transformed_gm,
192        graph=transformed_gm.graph,
193        graph_signature=_get_updated_graph_signature(
194            self.graph_signature, transformed_gm
195        ),
196        state_dict=self.state_dict,
197        range_constraints=_get_updated_range_constraints(transformed_gm),
198        module_call_graph=copy.deepcopy(self._module_call_graph),
199        example_inputs=self.example_inputs,
200        constants=self.constants,
201        verifiers=[self.verifier],
202    )
203    transformed_ep.graph_module.meta.update(self.graph_module.meta)
204    transformed_ep.graph_module.meta.update(res.graph_module.meta)
205    return transformed_ep
206
207
208def _copy_module(new_prog, new_gm):
209    new_prog.meta.update(new_gm.meta)
210    new_prog.graph = new_gm.graph
211    submodules = [name for name, _ in new_prog.named_children()]
212    for name in submodules:
213        delattr(new_prog, name)
214    for name, mod in new_gm.named_children():
215        setattr(new_prog, name, mod)
216    for node in new_gm.graph.nodes:
217        if node.op == "get_attr":
218            t = getattr(new_gm, node.target, None)
219            if isinstance(t, torch.Tensor):
220                setattr(new_prog, node.target, t)
221
222
223def lift_constant_tensor_pass(ep):
224    """
225    Takes an ExportedProgram and returns the ExportedProgram modified in-place,
226    with the constant tensors as buffers.
227    """
228    if len([node for node in ep.graph.nodes if node.op == "placeholder"]) == 0:
229        return ep
230
231    graph_signature = ep.graph_signature
232    buffers = list(graph_signature.buffers)
233
234    fake_mode = list(ep.graph.nodes)[0].meta["val"].fake_mode
235    first_user_input = None
236    lifted_constants = []
237    for node in ep.graph.nodes:
238        if node.op == "placeholder" and node.name in graph_signature.user_inputs:
239            first_user_input = node
240            break
241
242    for node in ep.graph.nodes:
243        if node.op == "get_attr":
244            constant_tensor = getattr(ep.graph_module, node.target)
245            if not isinstance(constant_tensor, torch.Tensor):
246                continue
247
248            constant_tensor_fqn = f"_lifted_tensor_constant{len(buffers)}"
249
250            with ep.graph.inserting_before(first_user_input):
251                # Insert the constant node before the first user input
252                const_placeholder_node = ep.graph.placeholder(constant_tensor_fqn)
253                for k, v in node.meta.items():
254                    const_placeholder_node.meta[k] = v
255                if fake_mode is not None:
256                    const_placeholder_node.meta["val"] = fake_mode.from_tensor(
257                        constant_tensor, static_shapes=True
258                    )
259                else:
260                    const_placeholder_node.meta["val"] = constant_tensor
261                const_placeholder_node.meta["val"].constant = constant_tensor
262                node.replace_all_uses_with(const_placeholder_node)
263                ep.graph.erase_node(node)
264
265                # Add the constant as a buffer to the graph signature
266                lifted_constants.append(
267                    InputSpec(
268                        kind=InputKind.BUFFER,
269                        arg=TensorArgument(name=const_placeholder_node.name),
270                        target=constant_tensor_fqn,
271                        persistent=True,
272                    )
273                )
274                buffers.append(constant_tensor_fqn)
275                ep.state_dict[constant_tensor_fqn] = constant_tensor
276
277    new_input_specs = []
278    for s in graph_signature.input_specs:
279        if s.kind == InputKind.USER_INPUT and len(lifted_constants) > 0:
280            new_input_specs.extend(lifted_constants)
281            lifted_constants.clear()
282        new_input_specs.append(s)
283    ep.graph_signature.input_specs = new_input_specs
284    ep.graph_module.recompile()
285    return ep
286
287
288# Stub to ease migration from `transform` to private `_transform`
289def transform_exported_program(ep, *passes: PassType) -> ExportedProgram:
290    if hasattr(ep, "_transform"):
291        return ep._transform(*passes)
292    else:
293        return ep.transform(*passes)
294
295
296class HackedUpExportedProgramDONOTUSE(ExportedProgram):
297    def __init__(
298        self,
299        root,
300        graph,
301        graph_signature,
302        call_spec,
303        state_dict,
304        range_constraints,
305        module_call_graph,
306        example_inputs,
307        verifier,
308    ):
309        super().__init__(
310            root=root,
311            graph=graph,
312            graph_signature=graph_signature,
313            state_dict=state_dict,
314            range_constraints=range_constraints,
315            module_call_graph=module_call_graph,
316            example_inputs=example_inputs,
317            verifier=verifier,
318        )
319
320    def __call__(self, *args: Any, **kwargs: Any) -> Any:
321        import torch._export.error as error
322
323        if self.call_spec.in_spec is not None:
324            user_args = args
325            try:
326                args = fx_pytree.tree_flatten_spec(user_args, self.call_spec.in_spec)  # type: ignore[assignment]
327            except Exception:
328                _, received_spec = pytree.tree_flatten(user_args)
329                raise error.InternalError(
330                    "Trying to flatten user inputs with exported input tree spec: \n"
331                    f"{self.call_spec.in_spec}\n"
332                    "but actually got inputs with tree spec of: \n"
333                    f"{received_spec}"
334                )
335
336        ordered_params = tuple(
337            self.state_dict[name] for name in self.graph_signature.parameters
338        )
339        ordered_buffers = tuple(
340            self.state_dict[name] for name in self.graph_signature.buffers
341        )
342
343        with torch.no_grad():
344            # NOTE: calling convention is first params, then buffers, then args as user supplied them.
345            # See: torch/_functorch/aot_autograd.py#L1034
346            res = torch.fx.Interpreter(self.graph_module).run(
347                *ordered_params, *ordered_buffers, *args, enable_io_processing=False
348            )
349
350        if self.call_spec.out_spec is not None:
351            mutation = self.graph_signature.buffers_to_mutate
352            num_mutated = len(mutation)
353            mutated_buffers = res[:num_mutated]
354
355            # Exclude dependency token from final result.
356            assertion_dep_token = self.graph_signature.assertion_dep_token
357            if assertion_dep_token is not None:
358                assertion_dep_token_index = list(assertion_dep_token.keys())[0]
359                res = res[:assertion_dep_token_index]
360
361            res = res[num_mutated:]
362            try:
363                res = pytree.tree_unflatten(res, self.call_spec.out_spec)
364            except Exception:
365                _, received_spec = pytree.tree_flatten(res)
366                raise error.InternalError(
367                    "Trying to flatten user outputs with exported output tree spec: \n"
368                    f"{self.call_spec.out_spec}\n"
369                    "but actually got outputs with tree spec of: \n"
370                    f"{received_spec}"
371                )
372            finally:
373                ix = 0
374                for buffer in self.graph_signature.buffers_to_mutate.values():
375                    self.state_dict[buffer] = mutated_buffers[ix]
376                    ix += 1
377        return res
378
379
380@compatibility(is_backward_compatible=False)
381class ExirExportedProgram:
382    def __init__(
383        self,
384        exported_program: ExportedProgram,
385        after_to_edge_passes: bool,
386    ):
387        self.exported_program = exported_program
388
389        # Add a flag to denote whehter to_edge is called on this program
390        # to detect misusage of directly calling to_executorch without to_edge
391        self.after_to_edge_passes = after_to_edge_passes
392
393    def transform(self, *passes: PassType) -> "ExirExportedProgram":
394        self.exported_program = _transform(self.exported_program, *passes)
395        return self
396
397    def __call__(self, *args: Any) -> Any:
398        return self.exported_program.module()(*args)
399
400    # TODO(ycao): Change this to a composable function.
401    def to_edge(
402        self, config: Optional[EdgeCompileConfig] = None
403    ) -> "ExirExportedProgram":
404        config = config or EdgeCompileConfig()
405        assert isinstance(
406            self.exported_program.graph_module, torch.fx.GraphModule
407        ), f"type is instead: {type(self.exported_program.graph_module).__name__}"
408
409        return _to_edge(self, config)
410
411    def dump(self) -> None:
412        print(self.exported_program.graph_module.graph)
413
414    def to_executorch(
415        self,
416        config: Optional[ExecutorchBackendConfig] = None,
417    ) -> "ExecutorchProgram":
418        if not self.after_to_edge_passes:
419            raise RuntimeError("Must run to_edge before to_executorch.")
420        config = config or ExecutorchBackendConfig()
421        new_gm = self.exported_program.graph_module
422        for p in edge_to_executorch_passes(config):
423            new_gm_res = p(new_gm)
424            assert new_gm_res is not None
425            new_gm = new_gm_res.graph_module
426
427        # This is tech debt on tech debt. memory planning pass inherits from some pass infra for GMs.
428        # This isnt enough info now so i cant use call I have to use some new function 'run'.
429        # Existing user passes dont use run so Im just cheating here because they dont need to work on mutable buffers yet.
430        # After exir.capture is gone I will clean up the memory planning infra to be consistent.
431        # Frankly all of exir has big code quality issues because of the migrations that need to be addressed.
432        new_gm_res = config.memory_planning_pass(new_gm)  # pyre-ignore[29]
433        assert new_gm_res is not None
434        new_gm = new_gm_res.graph_module
435        new_prog = ExirExportedProgram(
436            copy.deepcopy(self.exported_program), self.after_to_edge_passes
437        )
438        _copy_module(new_prog.exported_program.graph_module, new_gm)
439        executorch_prog = ExecutorchProgram(
440            new_prog,
441            emit_stacktrace=config.emit_stacktrace,
442            extract_delegate_segments=config.extract_delegate_segments,
443            segment_alignment=config.segment_alignment,
444            constant_tensor_alignment=config.constant_tensor_alignment,
445            delegate_alignment=config.delegate_alignment,
446        )
447        executorch_prog.graph_module.meta.update(new_gm.meta)
448        executorch_prog.graph_module.meta.update(
449            self.exported_program.graph_module.meta
450        )
451        return executorch_prog
452
453    def __deepcopy__(
454        self, memo: Optional[Dict[int, Any]] = None
455    ) -> "ExirExportedProgram":
456        new_eep = ExirExportedProgram(
457            copy.deepcopy(self.exported_program, memo),
458            self.after_to_edge_passes,
459        )
460        return new_eep
461
462
463@compatibility(is_backward_compatible=False)
464class ExecutorchProgram:
465    def __init__(
466        self,
467        exir_exported_program: ExirExportedProgram,
468        emit_stacktrace: bool,
469        extract_delegate_segments: bool,
470        segment_alignment: int,
471        constant_tensor_alignment: Optional[int] = None,
472        delegate_alignment: Optional[int] = None,
473    ) -> None:
474        if not exir_exported_program.after_to_edge_passes:
475            raise RuntimeError(
476                "Need to call prog.to_edge prior to constructing ExecutorchProgram."
477            )
478        self.exported_program = exir_exported_program.exported_program
479        self._pte_data: Optional[Cord] = None
480        self._buffer: Optional[bytes] = None
481        self._emitter_output: Optional[EmitterOutput] = None
482        self._emit_stacktrace: bool = emit_stacktrace
483        self._extract_delegate_segments: bool = extract_delegate_segments
484        self._segment_alignment: int = segment_alignment
485        self._constant_tensor_alignment: Optional[int] = constant_tensor_alignment
486        self._delegate_alignment: Optional[int] = delegate_alignment
487
488    def _get_pte_data(self) -> Cord:
489        if self._pte_data is None:
490            self._pte_data = _serialize_pte_binary(
491                program=self.program,
492                extract_delegate_segments=self._extract_delegate_segments,
493                segment_alignment=self._segment_alignment,
494                constant_tensor_alignment=self._constant_tensor_alignment,
495                delegate_alignment=self._delegate_alignment,
496            )
497        return self._pte_data
498
499    @property
500    def buffer(self) -> bytes:
501        """Returns the serialized ExecuTorch binary as a byte string.
502
503        Note that the call to `buffer` may allocate a very large amount of
504        contiguous memory, depending on the model size. If writing to a file,
505        use `write_to_file` which won't incur additional copies.
506        """
507        # TODO(T181494963): update pybinding to remove buffer cache, which can consume large
508        # amounts of memory longer than necessary.
509        if self._buffer is None:
510            self._buffer = bytes(self._get_pte_data())
511        return self._buffer
512
513    @property
514    def program(self) -> Program:
515        if self._emitter_output is None:
516            self._emitter_output = emit_program(
517                self.exported_program, self._emit_stacktrace
518            )
519        return self._emitter_output.program
520
521    @property
522    def debug_handle_map(self) -> Dict[int, Union[int, List[int]]]:
523        if self._emitter_output:
524            return self._emitter_output.debug_handle_map
525        return {}
526
527    @property
528    def delegate_map(
529        self,
530    ) -> Dict[str, Dict[int, Dict[str, Union[str, _DelegateDebugIdentifierMap]]]]:
531        if self._emitter_output:
532            return self._emitter_output.method_to_delegate_debug_id_map
533        return {}
534
535    @property
536    def graph_module(self) -> torch.fx.GraphModule:
537        return self.exported_program.graph_module
538
539    # TODO (zhxchen17) Change this to property.
540    def dump_graph_module(self) -> torch.fx.GraphModule:
541        return self.exported_program.graph_module
542
543    def dump_exported_program(self) -> ExportedProgram:
544        return self.exported_program
545
546    def write_to_file(self, open_file: io.BufferedIOBase) -> None:
547        """
548        Writes the serialized ExecuTorch binary to the file at `open_file`. Prefer to use this over
549        `buffer`, as it writes to file without copying into a contiguous block of memory first,
550        reducing the peak memory usage.
551        """
552        self._get_pte_data().write_to_file(open_file)
553
554
555def _get_aten_to_edge_passes(config: EdgeCompileConfig):
556    # TODO: the last two passes for aten_to_edge need to be eliminated_dead_code -> debug_handle_generator. After enable
557    # use_edge_op it can be moved to aten_to_edge_passes before eliminated_dead_code pass. Also ExportPass doesn't play
558    # well with node.meta, meaning after some passes permuting operators, we may lose some information in node.meta.
559    # It might be regenerated in SpecPropPass so it may not be visiable. However debug handle will be lost.
560
561    pre_op_replace_passes = base_pre_op_replace_passes + (
562        [] if config._skip_type_promotion else [RemoveMixedTypeOperators()]
563    )
564
565    post_op_replace_passes = base_post_op_replace_passes
566
567    return pre_op_replace_passes, post_op_replace_passes
568
569
570def _to_edge(ep, config: EdgeCompileConfig) -> "ExirExportedProgram":
571    if config._check_ir_validity:
572        try:
573            EXIRATenDialectVerifier()(ep.exported_program.graph_module)
574        except ExportError:
575            logging.info(
576                "If a particular operator failed core ATen IR check, please consider adding it to the exception list. "
577                "Add the operator to _core_aten_ops_exception_list in EdgeCompileConfig. This is the recommended way "
578                "to resolve this type of failure, so that the rest of the IR validation check can still be performed.\n"
579                "If you'd like to disable IR validation checking, please set _check_ir_validity in EdgeCompileConfig, "
580                "like *.to_edge(exir.EdgeCompileConfig(_check_ir_validity=False))."
581            )
582            raise
583
584    dialect = ep.exported_program.dialect
585    if dialect == "ATEN":
586        ep = ExirExportedProgram(
587            ExportedProgram(
588                root=ep.exported_program.graph_module,
589                graph=ep.exported_program.graph_module.graph,
590                graph_signature=ep.exported_program.graph_signature,
591                state_dict=ep.exported_program.state_dict,
592                range_constraints=ep.exported_program.range_constraints,
593                module_call_graph=ep.exported_program.module_call_graph,
594                example_inputs=ep.exported_program.example_inputs,
595                constants=ep.exported_program.constants,
596                verifiers=[
597                    get_aten_verifier(
598                        config=config,
599                    )
600                ],
601            ),
602            False,
603        )
604    pre_op_replace_passes, post_op_replace_passes = _get_aten_to_edge_passes(config)
605
606    new_ep = copy.deepcopy(ep).transform(*pre_op_replace_passes)
607    if dialect == "ATEN":
608        new_ep.exported_program = lift_constant_tensor_pass(new_ep.exported_program)
609
610    new_gm = new_ep.exported_program.graph_module
611    if config._use_edge_ops:
612        new_gm_res = OpReplacePass()(new_gm)
613        assert new_gm_res is not None
614        new_gm = new_gm_res.graph_module
615        if not config._skip_dim_order:
616            new_gm_res = MemoryFormatOpsPass()(new_gm)
617            assert new_gm_res is not None
618            new_gm = new_gm_res.graph_module
619
620    for p in post_op_replace_passes:
621        new_gm_res = p(new_gm)
622        assert new_gm_res is not None
623        new_gm = new_gm_res.graph_module
624
625    new_ep.exported_program = ExportedProgram(
626        root=new_gm,
627        graph=new_gm.graph,
628        graph_signature=_get_updated_graph_signature(
629            new_ep.exported_program.graph_signature, new_gm
630        ),
631        state_dict=new_ep.exported_program.state_dict,
632        range_constraints=new_ep.exported_program.range_constraints,
633        module_call_graph=new_ep.exported_program.module_call_graph,
634        example_inputs=new_ep.exported_program.example_inputs,
635        constants=new_ep.exported_program.constants,
636        verifiers=[
637            EXIREdgeDialectVerifier(
638                edge_compile_config=config,
639                class_only=True,
640            )
641        ],
642    )
643    new_ep.after_to_edge_passes = True
644    return new_ep
645
646
647def pre_memory_planning_passes(
648    config: ExecutorchBackendConfig, name: Optional[str] = None
649) -> List[PassType]:
650    """
651    Returns a list of passes to run before memory planning.
652    Get the sym shape eval pass based on the method name, if the pass is not in the dict, use the default pass.
653    """
654    # Handle symbolic shape eval pass
655    if isinstance(config.sym_shape_eval_pass, dict):
656        default_pass = ExecutorchBackendConfig().sym_shape_eval_pass
657        if not name:
658            sym_shape_eval_pass = default_pass
659        # pyre-ignore: Undefined attribute [16]
660        sym_shape_eval_pass = config.sym_shape_eval_pass.get(name, default_pass)
661    elif isinstance(config.sym_shape_eval_pass, PassBase):
662        sym_shape_eval_pass = config.sym_shape_eval_pass
663    else:
664        raise RuntimeError(
665            f"sym_shape_eval_pass must be a dict or a PassBase, got {config.sym_shape_eval_pass}"
666        )
667    if config.remove_view_copy:
668        return [
669            NormalizeViewCopyBasePass(),
670            dead_code_elimination_pass,
671            ReplaceViewCopyWithViewPass(),
672            sym_shape_eval_pass,
673            config.to_out_var_pass,
674        ]
675    else:
676        return [
677            sym_shape_eval_pass,
678            config.to_out_var_pass,
679        ]
680
681
682def edge_to_executorch_passes(
683    config: ExecutorchBackendConfig, name: Optional[str] = None
684) -> List[PassType]:
685    """
686    Returns a list of passes to lower from edge to executorch.
687    Get the pre memory planning passes based on the method name, if the pass is not in the dict, use the default pass.
688    """
689    passes: List[PassType] = [
690        *config.passes,
691        SpecPropPass(),
692        # ExecuTorch backend ops are unable to handle unbacked symints. So after
693        # this pass, passes cannot be Interpreter-based, because it will fail if
694        # there exists an unbacked symint operation.
695        EdgeToBackendOpsPass(),
696        RemoveGraphAssertsPass(),
697    ] + pre_memory_planning_passes(config, name)
698
699    return passes
700
701
702def _generate_edge_program(
703    name: str,
704    config: EdgeCompileConfig,
705    program: ExportedProgram,
706    ops_set_to_not_decompose: Optional[List[torch._ops.OpOverload]] = None,
707) -> ExportedProgram:
708    if config._check_ir_validity:
709        try:
710            EXIRATenDialectVerifier(
711                edge_compile_config=config,
712                class_only=False,
713                exception_list=ops_set_to_not_decompose,
714            )(program.graph_module)
715        except ExportError as e:
716            logging.info(f"Input program {name} is not in ATen dialect.")
717            raise e
718
719    pre_op_replace_passes, post_op_replace_passes = _get_aten_to_edge_passes(config)
720
721    passes = []
722    passes.append(
723        ReplaceViewOpsWithViewCopyOpsPass()
724    )  # TODO move inside aten_to_edge passes after all users are migrated off v1 capture
725    passes.extend(pre_op_replace_passes)
726    if config._use_edge_ops:
727        passes.append(OpReplacePass())
728        if not config._skip_dim_order:
729            passes.append(MemoryFormatOpsPass())
730
731    gm = program.graph_module
732    for p in passes:
733        gm_res = p(gm)
734        assert gm_res is not None
735        gm = gm_res.graph_module
736
737    edge_program = ExportedProgram(
738        root=gm,
739        graph=gm.graph,
740        graph_signature=_get_updated_graph_signature(program.graph_signature, gm),
741        state_dict=program.state_dict,
742        range_constraints=program.range_constraints,
743        module_call_graph=program.module_call_graph,
744        example_inputs=program.example_inputs,
745        constants=program.constants,
746        verifiers=[
747            EXIREdgeDialectVerifier(
748                edge_compile_config=config,
749                class_only=True,
750                exception_list=ops_set_to_not_decompose,
751            )
752        ],
753    )
754    # Lift the tensor constants created in ScalarToTensorPass
755    edge_program = lift_constant_tensor_pass(edge_program)
756    edge_program = _transform(edge_program, *post_op_replace_passes)
757
758    return edge_program
759
760
761def _replace_aten_ops_with_transformed_ops(
762    name: str,
763    program: ExportedProgram,
764    partitioner,
765):
766    ops_to_not_decompose = set()
767    partitioners = partitioner.get(name)
768    if partitioners is None:
769        return
770
771    # Iterate through the graph and replace the aten ops with the corresponding
772    # transformed ops.
773    for partitioner in partitioners:
774        ops_set_to_not_decompose, check_op_support = partitioner.ops_to_not_decompose(
775            program
776        )
777
778        for op_aten in ops_set_to_not_decompose:
779            _register_no_decomp_op(op_aten)
780
781        for node in program.graph.nodes:
782            is_op_supported = check_op_support(node) if check_op_support else True
783            if (
784                node.op == "call_function"
785                and node.target in ops_set_to_not_decompose
786                and is_op_supported
787            ):
788                ops_to_not_decompose.add(node.target)
789                node.target = aten_op_to_transform_op[node.target]
790
791        for _, submod, _ in get_control_flow_submodules(program.graph_module):
792            for node in submod.graph.nodes:
793                is_op_supported = check_op_support(node) if check_op_support else True
794                if (
795                    node.op == "call_function"
796                    and node.target in ops_set_to_not_decompose
797                    and is_op_supported
798                ):
799                    ops_to_not_decompose.add(node.target)
800                    node.target = aten_op_to_transform_op[node.target]
801
802    return ops_to_not_decompose
803
804
805def _restore_transformed_ops_to_aten_ops(program: ExportedProgram):
806    # Iterate through the graph and replace back the transformed ops with their
807    # corresponding aten ops.
808    for node in program.graph.nodes:
809        if node.op == "call_function" and str(node.target) in transform_op_to_aten_op:
810            node.target = transform_op_to_aten_op[str(node.target)]
811    for _, submod, _ in get_control_flow_submodules(program.graph_module):
812        for node in submod.graph.nodes:
813            if (
814                node.op == "call_function"
815                and str(node.target) in transform_op_to_aten_op
816            ):
817                node.target = transform_op_to_aten_op[str(node.target)]
818
819
820# Returns the op in edge_no_decomp_namespace namespace for the aten
821# op that is passed in.
822def _get_transformed_op(op_aten):
823    op_name = op_aten._schema.name.split("::")[1]
824    overload_name = op_aten._schema.overload_name
825    assert hasattr(
826        torch.ops, edge_no_decomp_namespace
827    ), f"Couldn't find {edge_no_decomp_namespace} in torch.ops. Please make sure the Library has been registered."
828    op_namespace = getattr(torch.ops, edge_no_decomp_namespace)
829    op = getattr(op_namespace, op_name)
830    return getattr(op, overload_name)
831
832
833# Registers the op in edge_no_decomp_namespace namespace for the aten
834# op that is passed in if it is not already cached in the table.
835def _register_no_decomp_op(op_aten):
836    # Check if the op is already cached in the table. If not, then we need to
837    # create a new op in the edge_no_decomp_namespace namespace.
838    if aten_op_to_transform_op.get(op_aten) is None and isinstance(
839        op_aten, torch._ops.OpOverload
840    ):
841        # Extract the schema from the aten op.
842        op_schema = str(op_aten._schema).split("::")[1]
843        op_name = op_aten._schema.name.split("::")[1]
844        # Define an op in the edge_no_decomp_namespace namespace with the aten schema.
845        lib.define(op_schema)
846        # Define the implementation of the op in the edge_no_decomp_namespace namespace.
847        # Important to note that the implementation of the op is the same as the aten op.
848
849        overload_name = op_aten._schema.overload_name
850        if overload_name != "":
851            op_name += "." + overload_name
852        lib.impl(op_name, op_aten, "CompositeExplicitAutograd")
853
854        # Cache the aten op and transformed op in their corresponding tables for future use.
855        aten_op_to_transform_op[op_aten] = _get_transformed_op(op_aten)
856        transform_op_to_aten_op[str(aten_op_to_transform_op[op_aten])] = op_aten
857
858
859def _sanity_check_graph_for_non_decomp_ops(
860    name: str,
861    program: ExportedProgram,
862    ops_set_to_not_decompose,
863    check_op_support,
864    generate_error=False,
865    partitioner_name=None,
866):
867    warning_str = f"Found {ops_set_to_not_decompose} in edge dialect program {name}."
868    if partitioner_name is not None:
869        warning_str += f" This op was registered by the partitioner {partitioner_name} to not be decomposed."
870
871    # Check that the ops that were registered to not be decomposed are not present in the
872    # graph anymore as the transform passes and backends should have consumed them by now.
873    ops_set_to_not_decompose = {
874        aten_to_edge(op) for op in ops_set_to_not_decompose
875    }.union(ops_set_to_not_decompose)
876    for node in program.graph_module.graph.nodes:
877        is_op_supported = check_op_support(node) if check_op_support else True
878        if (
879            node.op == "call_function" and node.target in ops_set_to_not_decompose
880        ) and is_op_supported:
881            if generate_error:
882                raise RuntimeError(warning_str)
883            else:
884                logging.warning(warning_str)
885    for _, submod, _ in get_control_flow_submodules(program.graph_module):
886        for node in submod.graph.nodes:
887            is_op_supported = check_op_support(node) if check_op_support else True
888            if (
889                node.op == "call_function" and node.target in ops_set_to_not_decompose
890            ) and is_op_supported:
891                if generate_error:
892                    raise RuntimeError(warning_str)
893                else:
894                    logging.warning(warning_str)
895
896
897def _gen_edge_manager_for_partitioners(
898    partitioner: Dict[str, List[Partitioner]],
899    aten_programs: Dict[str, ExportedProgram],
900    config: EdgeCompileConfig,
901    constant_methods: Optional[Dict[str, Any]],
902) -> "EdgeProgramManager":
903    """
904    Generates EdgeProgramManager for subsequent lowering to the
905    partitioners specified by partitioner. The EdgeProgramManager is generated from
906    aten_programs.
907
908    Partitioners specify what nodes should not be decomposed from the original aten programs.
909    This is done through two passes of run_decompositions.
910        - First pass preserves all aten_targets specified by partitioners to preserve
911          them from nested decompositions
912        - Second pass uses check_op fn provided by partitioners to perform additional checks
913          on nodes with preserved aten targets. They are then replaces with transformed ops to
914          keep them through the second pass of decompositions
915    """
916    ops_set_to_not_decompose_by_program = {}
917    edge_programs: Dict[str, ExportedProgram] = {}
918    for name, program in aten_programs.items():
919        if partitioner is not None:
920            # preserve all ops listed by all partitioners first
921            all_ops_no_decomp = set()
922            for curr_partitioner in partitioner.get(name, []):
923                curr_ops_no_decomp, _ = curr_partitioner.ops_to_not_decompose(program)
924                all_ops_no_decomp |= set(curr_ops_no_decomp)
925
926            table = _default_decomposition_table()
927
928            for op in all_ops_no_decomp:
929                table.pop(op, None)
930
931            program = program.run_decompositions(table)
932            # Among all the preserved aten ops, use the check_op_fn to do an additional
933            # check on which ops need to be preserved and which ops need to be decomposed
934            # Those which are truly preserved will be replaced with transformed ops
935            ops_set_to_not_decompose_by_program[name] = (
936                _replace_aten_ops_with_transformed_ops(name, program, partitioner) or []
937            )
938        program = program.run_decompositions(_default_decomposition_table())
939
940        _restore_transformed_ops_to_aten_ops(program)
941
942        edge_programs[name] = program
943
944        edge_programs[name] = _generate_edge_program(
945            name,
946            config,
947            program,
948            list(ops_set_to_not_decompose_by_program.get(name, [])),
949        )
950
951    edge_manager = EdgeProgramManager(
952        edge_programs,
953        constant_methods,
954        config,
955        list(set().union(*ops_set_to_not_decompose_by_program.values())),
956    )
957    return edge_manager
958
959
960def to_edge_transform_and_lower(
961    programs: Union[ExportedProgram, Dict[str, ExportedProgram]],
962    transform_passes: Optional[
963        Union[Sequence[PassType], Dict[str, Sequence[PassType]]]
964    ] = None,
965    partitioner: Optional[
966        Union[List[Partitioner], Dict[str, List[Partitioner]]]
967    ] = None,
968    constant_methods: Optional[Dict[str, Any]] = None,
969    compile_config: Optional[EdgeCompileConfig] = None,
970) -> "EdgeProgramManager":
971    """
972    :func:`to_edge_transform_and_lower` constructs an EdgeProgramManager from a set of
973    exported programs in ATen dialect. It differs fundamentally from to_edge in that it
974    combines the conversion of the ATen dialect to the edge dialect program, then running
975    the transformation passes and then subsequently lowering the programs to their
976    corresponding backends all into a single API.
977
978    This is fundamentally useful for lowering to backends that have ops registered that they
979    do not want to be decomposed and thus rely on matching with these non-decomposed ops. For
980    these sorts of backends this is the *only* API that should be used to lower to the edge
981    dialect. Using a combination of to_edge(...) and to_backend(...) will result in inconsistent
982    or wrong behavior.
983
984    This API is the primary recommended way to lower to the CPU based XNNPack backend.
985
986    Args:
987        programs: Can be a single ExportedProgram or a dictionary mapping function names
988            to their corresponding ExportedPrograms. If only a single ExportedProgram is
989            provided it will be assigned the name "forward".
990
991        transform_passes: The passes can either be a list of passes, or a dictionary
992            mapping method names to lists of passes. If it is just a list of passes, all methods
993            in the given EdgeProgramManager will be transformed with the provided passes. If it
994            is a dictionary, only method names specified in the dictionary will be transformed
995            with their corresponding passes.
996
997        partitioner: The partitioner can either be a Partitioner subclass instance, or a
998            dictionary mapping method names to Partitioner subclass instance. If it is a
999            Partitioner subclass, all programs in the given EdgeProgramManager will be lowered
1000            using the given partitioner. If it is a dictionary, only method names specified in
1001            the dictionary will be lowered with the given partitioner.
1002
1003        constant_methods: An optional dictionary of method name to the constant value
1004            returned by that method in eager mode. Often used to store config information on
1005            Edge models.
1006
1007        compile_config: An optional argument used to provide greater control over the
1008            transformation to edge dialect process.
1009
1010    Returns:
1011        EdgeProgramManager
1012    """
1013    assert not isinstance(constant_methods, EdgeCompileConfig)
1014    config = compile_config or EdgeCompileConfig()
1015    if not isinstance(programs, dict):
1016        aten_programs = {"forward": programs}
1017    else:
1018        aten_programs = programs
1019
1020    if not isinstance(partitioner, dict) and partitioner is not None:
1021        partitioner = {name: partitioner for name in aten_programs.keys()}
1022    elif partitioner is None:
1023        partitioner = {name: [] for name in aten_programs.keys()}
1024
1025    edge_manager = _gen_edge_manager_for_partitioners(
1026        partitioner, aten_programs, config, constant_methods
1027    )
1028
1029    if transform_passes is not None:
1030        edge_manager = edge_manager.transform(transform_passes)
1031
1032    if partitioner is not None:
1033        for name, partitioner_list in partitioner.items():
1034            for curr_partitioner in partitioner_list:
1035                edge_manager = edge_manager.to_backend({name: curr_partitioner})
1036
1037    for name, program in edge_manager._edge_programs.items():
1038        ops_set_to_not_decompose: Set[torch._ops.OpOverload] = set()
1039        partitioners = partitioner.get(name, [])
1040        for curr_partitioner in partitioners:
1041            curr_op_set, check_op_support = curr_partitioner.ops_to_not_decompose(
1042                program
1043            )
1044            ops_set_to_not_decompose = ops_set_to_not_decompose.union(curr_op_set)
1045            _sanity_check_graph_for_non_decomp_ops(
1046                name,
1047                program,
1048                ops_set_to_not_decompose,
1049                check_op_support,
1050                partitioner_name=curr_partitioner.__class__.__name__,
1051                generate_error=True,
1052            )
1053
1054        if config._check_ir_validity:
1055            EXIREdgeDialectVerifier(
1056                edge_compile_config=config,
1057                class_only=True,
1058                exception_list=list(ops_set_to_not_decompose),
1059            )()(program.graph_module)
1060
1061    return edge_manager
1062
1063
1064@experimental(
1065    """
1066    This is an experimental API which overloads to_edge by preserving specified ops to not be decomposed.
1067    This function will be combined with to_edge in the future.
1068    """
1069)
1070def to_edge_with_preserved_ops(
1071    programs: Union[ExportedProgram, Dict[str, ExportedProgram]],
1072    constant_methods: Optional[Dict[str, Any]] = None,
1073    compile_config: Optional[EdgeCompileConfig] = None,
1074    preserve_ops: Tuple[torch._ops.OpOverload, ...] = (),
1075) -> "EdgeProgramManager":
1076    """
1077    :func:`to_edge` constructs an EdgeProgramManager from a set of exported programs in
1078    ATen dialect. Upon construction those programs are transformed into edge dialect.
1079
1080    Args:
1081        programs: Can be a single ExportedProgram or a dictionary mapping function names to their corresponding ExportedPrograms. If only a single ExportedProgram is provided it will be assigned the name "forward".
1082        constant_methods: An optional dictionary of method name to the constant value returned by that method in eager mode. Often used to store config information on Edge models.
1083        compile_config: An optional argument used to provide greater control over the transformation to edge dialect process.
1084        preserve_ops: An argument used to specify ops that should not be decomposed.
1085
1086    Returns:
1087        EdgeProgramManager
1088    """
1089    assert not isinstance(constant_methods, EdgeCompileConfig)
1090    config = compile_config or EdgeCompileConfig()
1091    if not isinstance(programs, dict):
1092        aten_programs = {"forward": programs}
1093    else:
1094        aten_programs = programs
1095
1096    edge_programs: Dict[str, ExportedProgram] = {}
1097
1098    for name, program in aten_programs.items():
1099        # Decompose to Core ATen
1100        table = _default_decomposition_table()
1101        for op in preserve_ops:
1102            table.pop(op, None)
1103        program = program.run_decompositions(table)
1104        edge_programs[name] = _generate_edge_program(
1105            name, config, program, list(preserve_ops)
1106        )
1107
1108    return EdgeProgramManager(
1109        edge_programs, constant_methods, config, list(preserve_ops)
1110    )
1111
1112
1113def to_edge(
1114    programs: Union[ExportedProgram, Dict[str, ExportedProgram]],
1115    constant_methods: Optional[Dict[str, Any]] = None,
1116    compile_config: Optional[EdgeCompileConfig] = None,
1117) -> "EdgeProgramManager":
1118    """
1119    :func:`to_edge` constructs an EdgeProgramManager from a set of exported programs in
1120    ATen dialect. Upon construction those programs are transformed into edge dialect.
1121
1122    Args:
1123        programs: Can be a single ExportedProgram or a dictionary mapping function names to their corresponding ExportedPrograms. If only a single ExportedProgram is provided it will be assigned the name "forward".
1124
1125        constant_methods: An optional dictionary of method name to the constant value returned by that method in eager mode. Often used to store config information on Edge models.
1126
1127        compile_config: An optional argument used to provide greater control over the transformation to edge dialect process.
1128
1129    Returns:
1130        EdgeProgramManager
1131    """
1132    assert not isinstance(constant_methods, EdgeCompileConfig)
1133    config = compile_config or EdgeCompileConfig()
1134    if not isinstance(programs, dict):
1135        aten_programs = {"forward": programs}
1136    else:
1137        aten_programs = programs
1138
1139    edge_programs: Dict[str, ExportedProgram] = {}
1140
1141    for name, program in aten_programs.items():
1142        # Decompose to Core ATen
1143        program = program.run_decompositions(_default_decomposition_table())
1144        edge_programs[name] = _generate_edge_program(name, config, program)
1145
1146    return EdgeProgramManager(edge_programs, constant_methods, config)
1147
1148
1149class EdgeProgramManager:
1150    """
1151    Package of one or more `ExportedPrograms` in Edge dialect. Designed to simplify
1152    lowering to ExecuTorch. See: https://pytorch.org/executorch/stable/ir-exir.html
1153
1154    Allows easy applications of transforms across a collection of exported programs
1155    including the delegation of subgraphs.
1156
1157    Manages the second link in the lowering chain of ATen -> Edge -> ExecuTorch.
1158    """
1159
1160    def __init__(
1161        self,
1162        edge_programs: Union[ExportedProgram, Dict[str, ExportedProgram]],
1163        constant_methods: Optional[Dict[str, Any]] = None,
1164        compile_config: Optional[EdgeCompileConfig] = None,
1165        ops_set_to_not_decompose: Optional[List[torch._ops.OpOverload]] = None,
1166    ):
1167        """
1168        Should not be called directly by users. User should use :func:'to_edge' instead.
1169
1170        Constructs an EdgeProgramManager from an existing set of exported programs in edge dialect.
1171        """
1172        self.compile_config = compile_config or EdgeCompileConfig()
1173        if not isinstance(edge_programs, dict):
1174            edge_programs = {"forward": edge_programs}
1175
1176        for name, program in edge_programs.items():
1177            try:
1178                EXIREdgeDialectVerifier(
1179                    edge_compile_config=self.compile_config,
1180                    exception_list=ops_set_to_not_decompose,
1181                )(program.graph_module)
1182            except ExportError as e:
1183                logging.info(f"Input program {name} is not in aten dialect.")
1184                raise e
1185
1186        self._edge_programs: Dict[str, ExportedProgram] = edge_programs
1187        self._config_methods = constant_methods
1188
1189    @property
1190    def methods(self) -> Set[str]:
1191        """
1192        Returns the set of methods in this EdgeProgramManager.
1193        """
1194        return set(self._edge_programs.keys())
1195
1196    @property
1197    def config_methods(self) -> Set[str]:
1198        """
1199        Returns the set of config methods in this EdgeProgramManager.
1200        """
1201        return set(self._config_methods.keys()) if self._config_methods else set()
1202
1203    def exported_program(self, method_name: str = "forward") -> ExportedProgram:
1204        """
1205        Returns the ExportedProgram specified by 'method_name'.
1206        """
1207        return self._edge_programs[method_name]
1208
1209    def transform(
1210        self,
1211        passes: Union[Sequence[PassType], Dict[str, Sequence[PassType]]],
1212        compile_config: Optional[EdgeCompileConfig] = None,
1213    ) -> "EdgeProgramManager":
1214        """
1215        Transforms the program according to the provided passes.
1216
1217        Args:
1218            passes: The passes can either be a list of passes, or a
1219                dictionary mapping method names to lists of passes. If it is
1220                just a list of passes, all methods in the given EdgeProgramManager
1221                will be transformed with the provided passes. If it is a
1222                dictionary, only method names specified in the dictionary will be
1223                transformed with their corresponding passes.
1224            compile_config: Compile config to use for veriy the correctness of model
1225                graph after each pass. If not specified, the compile config of the
1226                calling EdgeProgramManager will be used. It will be used in as compile
1227                config of returned EdgeProgramManager.
1228
1229        Returns:
1230            EdgeProgramManager: A copy of the calling EdgeProgramManager with the
1231            transformations applied.
1232        """
1233        compile_config = compile_config or self.compile_config
1234        new_programs: Dict[str, ExportedProgram] = {}
1235        if isinstance(passes, dict):
1236            for name, program in self._edge_programs.items():
1237                if name in passes.keys():
1238                    new_programs[name] = _transform(program, *passes[name])
1239                    EXIREdgeDialectVerifier(edge_compile_config=compile_config)(
1240                        new_programs[name].graph_module
1241                    )
1242                else:
1243                    new_programs[name] = copy.deepcopy(program)
1244
1245        else:  # apply passes to every method
1246            for name, program in self._edge_programs.items():
1247                new_programs[name] = _transform(program, *passes)
1248                EXIREdgeDialectVerifier(edge_compile_config=compile_config)(
1249                    new_programs[name].graph_module
1250                )
1251
1252        return EdgeProgramManager(
1253            new_programs, copy.deepcopy(self._config_methods), compile_config
1254        )
1255
1256    def to_backend(
1257        self, partitioner: Union[Partitioner, Dict[str, Partitioner]]
1258    ) -> "EdgeProgramManager":
1259        """
1260        Returns a semantically-equivalent program to the one given as input,
1261        but with portions of each program in the EdgeProgramManager targeted
1262        for delegation as determined by the partitioner.
1263
1264        Args:
1265            partitioner: The partitioner can either be a Partitioner subclass instance, or a
1266                dictionary mapping method names to Partitioner subclass instance. If it is a
1267                Partitioner subclass, all programs in the given EdgeProgramManager
1268                will be lowered using the given partitioner. If it is a
1269                dictionary, only method names specified in the dictionary will be
1270                lowered with the given partitioner.
1271
1272                The Partitioner subclass instance is in charge with tagging portions of the
1273                input program for delegation. A valid partitioner must return PartitionerResult including valid
1274                partition_tags: Dict[str, DelegationSpec], where each key is a tag
1275                name and the nodes with same tag will be fused a one subgraph and
1276                delegated to backend specififed in delegation spec.
1277
1278        Returns:
1279            EdgeProgramManager: A copy of the calling EdgeProgramManager with the
1280            specified subgraphs lowered.
1281        """
1282        new_edge_programs: Dict[str, ExportedProgram] = {}
1283        if isinstance(partitioner, dict):
1284            for name, program in self._edge_programs.items():
1285                if name in partitioner.keys():
1286                    new_edge_programs[name] = to_backend(program, partitioner[name])
1287                else:
1288                    new_edge_programs[name] = program
1289
1290        else:  # apply partitioner to every method
1291            for name, program in self._edge_programs.items():
1292                new_edge_programs[name] = to_backend(program, partitioner)
1293
1294        config = EdgeCompileConfig(_check_ir_validity=False)
1295        return EdgeProgramManager(
1296            new_edge_programs, copy.deepcopy(self._config_methods), config
1297        )
1298
1299    def to_executorch(
1300        self,
1301        config: Optional[ExecutorchBackendConfig] = None,
1302    ) -> "ExecutorchProgramManager":
1303        """
1304        Transforms the program to the ExecuTorch backend.
1305
1306        Args:
1307            config: An optional argument used to provide greater control over
1308                the transformation to the ExecuTorch backend.
1309
1310        Returns:
1311            ExecutorchProgramManager: A manager representing the state of the EdgeProgramManager
1312            after it has been transformed to the ExecuTorch backend.
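
        Example (an illustrative sketch)::

            # use the default ExecutorchBackendConfig
            executorch_manager = edge_manager.to_executorch()
            # or pass an explicit config
            executorch_manager = edge_manager.to_executorch(ExecutorchBackendConfig())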
1313        """
1314        config = config if config else ExecutorchBackendConfig()
1315
1316        execution_programs: Dict[str, ExportedProgram] = {}
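        # For each method, lower the exported program to the executorch dialect:
        # run the edge-to-executorch passes and memory planning on the method's
        # graph module, then copy the result back into the exported program.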
1317        for name, program in self._edge_programs.items():
1318            program = weights_to_outputs_pass(program)
1319            program = unsafe_remove_auto_functionalized_pass(program)
1320            gm, new_signature = insert_write_back_for_buffers_pass(program)
1321            new_gm = program.graph_module
1322            for p in edge_to_executorch_passes(config, name):
1323                new_gm_res = p(new_gm)
1324                assert new_gm_res is not None
1325                new_gm = new_gm_res.graph_module
1326                if isinstance(p, SpecPropPass):
1327                    # Note that this is a hacky way to get around the fact that
1328                    # placeholder nodes corresponding to the parameters of the graph module
1329                    # shall not participate in memory planning. It increases the runtime
1330                    # memory footprint.
1331                    # The proper way would be to have ExportPass work with ExportedProgram
1332                    # instead of GraphModule, because ExportPass should work on top of the
1333                    # export artifact of torch.export, which is ExportedProgram. Working with
1334                    # GraphModule does not provide all the information contained in the
1335                    # ExportedProgram.
1336                    # TODO(who?)
1337                    p.update_placeholder_tensor_specs(program, new_gm)
1338
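            # Select the memory planning pass for this method: if the config provides a
            # dict, use the per-method entry (falling back to the default pass when the
            # method is not listed); otherwise use the single configured pass.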
1339            if isinstance(config.memory_planning_pass, dict):
1340                memory_planning_pass = config.memory_planning_pass.get(
1341                    name, ExecutorchBackendConfig().memory_planning_pass
1342                )
1343            else:
1344                memory_planning_pass = config.memory_planning_pass
1345            # TODO(jakeszwe): Follow up with compiler on whether the deepcopy is necessary and, if so, how to make it work
1346            if hasattr(memory_planning_pass, "run"):
1347                new_gm_res = memory_planning_pass.run(  # pyre-ignore[16]
1348                    new_gm, new_signature
1349                )
1350            else:
1351                new_gm_res = memory_planning_pass(new_gm)  # pyre-ignore[29]
1352            assert new_gm_res is not None
1353            new_gm = new_gm_res.graph_module
1354
1355            _copy_module(program.graph_module, new_gm)
1356            execution_programs[name] = program
1357
1358        return ExecutorchProgramManager(
1359            execution_programs, self._config_methods, config
1360        )
1361
1362
1363class ExecutorchProgramManager:
1364    """
1365    Package of one or more `ExportedPrograms` in Execution dialect. Designed to simplify
1366    lowering to ExecuTorch. See: https://pytorch.org/executorch/stable/ir-exir.html
1367
1368    When the ExecutorchProgramManager is constructed, the ExportedPrograms in execution dialect
1369    are used to form the ExecuTorch binary (in a process called emission), which is then serialized
1370    to a buffer.
1371
1372    Manages the final link in the lowering chain of ATen -> Edge -> ExecuTorch.
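
    A typical flow (illustrative sketch; ``edge_manager`` is an EdgeProgramManager
    and ``model.pte`` is an arbitrary output path)::

        executorch_manager = edge_manager.to_executorch()
        with open("model.pte", "wb") as f:
            executorch_manager.write_to_file(f)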
1373    """
1374
1375    def __init__(
1376        self,
1377        execution_programs: Dict[str, ExportedProgram],
1378        config_methods: Optional[Dict[str, Any]] = None,
1379        backend_config: Optional[ExecutorchBackendConfig] = None,
1380    ):
1381        """
1382        End users should not call this constructor directly. Instead, they should use
1383        :func:`to_executorch` to construct an ExecutorchProgramManager.
1384
1385        Constructs an ExecutorchProgramManager from a set of exported programs in
1386        execution dialect.
1387
1388        Args:
1389            execution_programs: A dictionary of method name to the corresponding
1390            ExportedProgram.
1391
1392            config_methods: A dictionary of method name to the config value returned
1393            by that method in eager mode.
1394
1395            backend_config: An optional argument used to provide greater control over
1396            the emission and serialization.
1397        """
1398        # Set up methods
1399        self._execution_programs: Dict[str, ExportedProgram] = execution_programs
1400        self._config_methods: Optional[Dict[str, Any]] = config_methods
1401
1402        backend_config = backend_config or ExecutorchBackendConfig()
1403
1404        # Emit methods
1405        self._emitter_output: EmitterOutput = emit_program(
1406            self._execution_programs,
1407            backend_config.emit_stacktrace,
1408            self._config_methods,
1409        )
1410
1411        # Serialize emitter output, ready to be written to a file.
1412        self._pte_data: Cord = _serialize_pte_binary(
1413            program=self._emitter_output.program,
1414            mutable_data=self._emitter_output.mutable_data,
1415            extract_delegate_segments=backend_config.extract_delegate_segments,
1416            segment_alignment=backend_config.segment_alignment,
1417            constant_tensor_alignment=backend_config.constant_tensor_alignment,
1418            delegate_alignment=backend_config.delegate_alignment,
1419        )
1420        self._buffer: Optional[bytes] = None
1421
1422    @property
1423    def methods(self) -> Set[str]:
1424        """
1425        Returns the set of methods in this ExecutorchProgramManager.
1426        """
1427        return set(self._execution_programs.keys())
1428
1429    @property
1430    def config_methods(self) -> Set[str]:
1431        """
1432        Returns the set of config methods in this ExecutorchProgramManager.
1433        """
1434        return set(self._config_methods.keys()) if self._config_methods else set()
1435
1436    def exported_program(self, method_name: str = "forward") -> ExportedProgram:
1437        """
1438        Returns the ExportedProgram specified by 'method_name'.
1439        """
1440        return self._execution_programs[method_name]
1441
1442    def dump_executorch_program(
1443        self, verbose: bool = False, out: Optional[TextIO] = None
1444    ) -> None:
1445        """
1446        Prints the ExecuTorch binary in a human-readable format.
1447
1448        Args:
1449            verbose (bool):
1450                If False, prints the binary in a condensed format.
1451                If True, prints the binary one-to-one with the specification in the schema.
1452            out:
1453                If None, prints to stdout.
1454                If non-None, writes the string to that stream object. It can be
1455                    a file object, a StringIO object, or any other TextIO subclass.
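
        Example (an illustrative sketch writing the condensed form to an in-memory
        stream)::

            import io

            stream = io.StringIO()
            executorch_manager.dump_executorch_program(verbose=False, out=stream)
            print(stream.getvalue())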
1456        """
1457        if verbose:
1458            pretty_print(self._emitter_output.program, out=out)
1459        else:
1460            print_program(self._emitter_output.program, out=out)
1461
1462    @property
1463    def debug_handle_map(self) -> Dict[int, Union[int, List[int]]]:
1464        return self._emitter_output.debug_handle_map
1465
1466    @property
1467    def delegate_map(
1468        self,
1469    ) -> Dict[str, Dict[int, Dict[str, Union[str, _DelegateDebugIdentifierMap]]]]:
1470        return self._emitter_output.method_to_delegate_debug_id_map
1471
1472    @property
1473    def executorch_program(self) -> Program:
1474        """
1475        Returns the object that represents the ExecuTorch binary before serialization.
1476        """
1477        return self._emitter_output.program
1478
1479    @property
1480    def buffer(self) -> bytes:
1481        """Returns the serialized ExecuTorch binary as a byte string.
1482
1483        Note that the call to `buffer` may allocate a very large amount of
1484        contiguous memory, depending on the model size. If writing to a file,
1485        use `write_to_file`, which won't incur additional copies.
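
        Illustrative sketch (materializes the entire binary in memory; prefer
        `write_to_file` when saving to disk)::

            pte_bytes = executorch_manager.buffer
            print(f"serialized program is {len(pte_bytes)} bytes")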
1486        """
1487        # TODO(T181494963): update pybinding to remove buffer cache, which can consume large
1488        # amounts of memory longer than necessary.
1489        if self._buffer is None:
1490            self._buffer = bytes(self._pte_data)
1491        return self._buffer
1492
1493    def write_to_file(self, open_file: io.BufferedIOBase) -> None:
1494        """
1495        Writes the serialized ExecuTorch binary to the open file `open_file`. Prefer this over
1496        `buffer`, as it writes to the file without first copying the data into a contiguous block
1497        of memory, reducing the peak memory usage.
1498        """
1499        self._pte_data.write_to_file(open_file)
1500