xref: /aosp_15_r20/external/executorch/exir/program/_program.py (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1*523fa7a6SAndroid Build Coastguard Worker# Copyright (c) Meta Platforms, Inc. and affiliates.
2*523fa7a6SAndroid Build Coastguard Worker# All rights reserved.
3*523fa7a6SAndroid Build Coastguard Worker#
4*523fa7a6SAndroid Build Coastguard Worker# This source code is licensed under the BSD-style license found in the
5*523fa7a6SAndroid Build Coastguard Worker# LICENSE file in the root directory of this source tree.
6*523fa7a6SAndroid Build Coastguard Worker
7*523fa7a6SAndroid Build Coastguard Worker# pyre-unsafe
8*523fa7a6SAndroid Build Coastguard Worker
9*523fa7a6SAndroid Build Coastguard Workerimport copy
10*523fa7a6SAndroid Build Coastguard Workerimport io
11*523fa7a6SAndroid Build Coastguard Workerimport logging
12*523fa7a6SAndroid Build Coastguard Workerfrom typing import Any, Dict, List, Optional, Sequence, Set, TextIO, Tuple, Union
13*523fa7a6SAndroid Build Coastguard Worker
14*523fa7a6SAndroid Build Coastguard Workerimport torch
15*523fa7a6SAndroid Build Coastguard Workerimport torch._export
16*523fa7a6SAndroid Build Coastguard Workerfrom executorch.exir._serialize import _serialize_pte_binary
17*523fa7a6SAndroid Build Coastguard Workerfrom executorch.exir._serialize._cord import Cord
18*523fa7a6SAndroid Build Coastguard Workerfrom executorch.exir._warnings import experimental
19*523fa7a6SAndroid Build Coastguard Workerfrom executorch.exir.backend.backend_api import to_backend
20*523fa7a6SAndroid Build Coastguard Workerfrom executorch.exir.backend.partitioner import Partitioner
21*523fa7a6SAndroid Build Coastguard Workerfrom executorch.exir.capture._config import EdgeCompileConfig, ExecutorchBackendConfig
22*523fa7a6SAndroid Build Coastguard Workerfrom executorch.exir.emit import emit_program, EmitterOutput
23*523fa7a6SAndroid Build Coastguard Workerfrom executorch.exir.emit._emitter import _DelegateDebugIdentifierMap
24*523fa7a6SAndroid Build Coastguard Workerfrom executorch.exir.error import ExportError
25*523fa7a6SAndroid Build Coastguard Workerfrom executorch.exir.graph_module import get_control_flow_submodules
26*523fa7a6SAndroid Build Coastguard Workerfrom executorch.exir.pass_base import PassBase
27*523fa7a6SAndroid Build Coastguard Workerfrom executorch.exir.pass_manager import PassType
28*523fa7a6SAndroid Build Coastguard Workerfrom executorch.exir.passes import (
29*523fa7a6SAndroid Build Coastguard Worker    base_post_op_replace_passes,
30*523fa7a6SAndroid Build Coastguard Worker    base_pre_op_replace_passes,
31*523fa7a6SAndroid Build Coastguard Worker    dead_code_elimination_pass,
32*523fa7a6SAndroid Build Coastguard Worker    EdgeToBackendOpsPass,
33*523fa7a6SAndroid Build Coastguard Worker    MemoryFormatOpsPass,
34*523fa7a6SAndroid Build Coastguard Worker    OpReplacePass,
35*523fa7a6SAndroid Build Coastguard Worker)
36*523fa7a6SAndroid Build Coastguard Workerfrom executorch.exir.passes.insert_write_back_for_buffers_pass import (
37*523fa7a6SAndroid Build Coastguard Worker    insert_write_back_for_buffers_pass,
38*523fa7a6SAndroid Build Coastguard Worker)
39*523fa7a6SAndroid Build Coastguard Workerfrom executorch.exir.passes.normalize_view_copy_base_pass import (
40*523fa7a6SAndroid Build Coastguard Worker    NormalizeViewCopyBasePass,
41*523fa7a6SAndroid Build Coastguard Worker)
42*523fa7a6SAndroid Build Coastguard Workerfrom executorch.exir.passes.remove_graph_asserts_pass import RemoveGraphAssertsPass
43*523fa7a6SAndroid Build Coastguard Workerfrom executorch.exir.passes.remove_mixed_type_operators import RemoveMixedTypeOperators
44*523fa7a6SAndroid Build Coastguard Workerfrom executorch.exir.passes.replace_aten_with_edge_pass import aten_to_edge
45*523fa7a6SAndroid Build Coastguard Workerfrom executorch.exir.passes.replace_view_copy_with_view_pass import (
46*523fa7a6SAndroid Build Coastguard Worker    ReplaceViewCopyWithViewPass,
47*523fa7a6SAndroid Build Coastguard Worker)
48*523fa7a6SAndroid Build Coastguard Workerfrom executorch.exir.passes.spec_prop_pass import SpecPropPass
49*523fa7a6SAndroid Build Coastguard Workerfrom executorch.exir.passes.weights_to_outputs_pass import weights_to_outputs_pass
50*523fa7a6SAndroid Build Coastguard Workerfrom executorch.exir.print_program import pretty_print, print_program
51*523fa7a6SAndroid Build Coastguard Workerfrom executorch.exir.schema import Program
52*523fa7a6SAndroid Build Coastguard Workerfrom executorch.exir.tracer import _default_decomposition_table
53*523fa7a6SAndroid Build Coastguard Workerfrom executorch.exir.verification.verifier import (
54*523fa7a6SAndroid Build Coastguard Worker    EXIRATenDialectVerifier,
55*523fa7a6SAndroid Build Coastguard Worker    EXIREdgeDialectVerifier,
56*523fa7a6SAndroid Build Coastguard Worker    get_aten_verifier,
57*523fa7a6SAndroid Build Coastguard Worker)
58*523fa7a6SAndroid Build Coastguard Workerfrom torch._export.passes import ReplaceViewOpsWithViewCopyOpsPass
59*523fa7a6SAndroid Build Coastguard Workerfrom torch.export import ExportedProgram
60*523fa7a6SAndroid Build Coastguard Workerfrom torch.export._remove_auto_functionalized_pass import (
61*523fa7a6SAndroid Build Coastguard Worker    unsafe_remove_auto_functionalized_pass,
62*523fa7a6SAndroid Build Coastguard Worker)
63*523fa7a6SAndroid Build Coastguard Workerfrom torch.export.exported_program import (
64*523fa7a6SAndroid Build Coastguard Worker    ConstantArgument,
65*523fa7a6SAndroid Build Coastguard Worker    ExportGraphSignature,
66*523fa7a6SAndroid Build Coastguard Worker    InputKind,
67*523fa7a6SAndroid Build Coastguard Worker    InputSpec,
68*523fa7a6SAndroid Build Coastguard Worker    OutputSpec,
69*523fa7a6SAndroid Build Coastguard Worker    TensorArgument,
70*523fa7a6SAndroid Build Coastguard Worker)
71*523fa7a6SAndroid Build Coastguard Workerfrom torch.fx import _pytree as fx_pytree
72*523fa7a6SAndroid Build Coastguard Workerfrom torch.fx._compatibility import compatibility
73*523fa7a6SAndroid Build Coastguard Workerfrom torch.fx.passes.infra.pass_manager import PassManager
74*523fa7a6SAndroid Build Coastguard Workerfrom torch.utils import _pytree as pytree
75*523fa7a6SAndroid Build Coastguard Worker
# Type alias; currently just ``Any``.
Val = Any

from torch.library import Library

# Reserved namespace into which ops are registered so that they are kept
# intact (not decomposed) during to_edge_transform_and_lower.
edge_no_decomp_namespace = "EDGE_DO_NOT_DECOMP"
# Import-time "DEF" library for the reserved namespace.
lib = Library(edge_no_decomp_namespace, "DEF")
# Forward lookup: aten op -> replacement op registered in the reserved namespace.
aten_op_to_transform_op: Dict[Any, Any] = {}
# Reverse lookup: replacement op registered in the reserved namespace -> aten op.
transform_op_to_aten_op: Dict[Any, Any] = {}
88*523fa7a6SAndroid Build Coastguard Worker
89*523fa7a6SAndroid Build Coastguard Worker
90*523fa7a6SAndroid Build Coastguard Workerdef _get_updated_range_constraints(gm):
91*523fa7a6SAndroid Build Coastguard Worker    def get_shape_env(gm):
92*523fa7a6SAndroid Build Coastguard Worker        vals = [
93*523fa7a6SAndroid Build Coastguard Worker            node.meta["val"]
94*523fa7a6SAndroid Build Coastguard Worker            for node in gm.graph.nodes
95*523fa7a6SAndroid Build Coastguard Worker            if node.meta.get("val", None) is not None
96*523fa7a6SAndroid Build Coastguard Worker        ]
97*523fa7a6SAndroid Build Coastguard Worker        from torch._guards import detect_fake_mode  # type: ignore[21]
98*523fa7a6SAndroid Build Coastguard Worker
99*523fa7a6SAndroid Build Coastguard Worker        fake_mode = detect_fake_mode(vals)
100*523fa7a6SAndroid Build Coastguard Worker        if fake_mode is not None:
101*523fa7a6SAndroid Build Coastguard Worker            return fake_mode.shape_env
102*523fa7a6SAndroid Build Coastguard Worker        for v in vals:
103*523fa7a6SAndroid Build Coastguard Worker            if isinstance(v, torch.SymInt):
104*523fa7a6SAndroid Build Coastguard Worker                return v.node.shape_env
105*523fa7a6SAndroid Build Coastguard Worker
106*523fa7a6SAndroid Build Coastguard Worker    shape_env = get_shape_env(gm)
107*523fa7a6SAndroid Build Coastguard Worker    if shape_env is None:
108*523fa7a6SAndroid Build Coastguard Worker        return {}
109*523fa7a6SAndroid Build Coastguard Worker    range_constraints = {
110*523fa7a6SAndroid Build Coastguard Worker        k: v
111*523fa7a6SAndroid Build Coastguard Worker        for k, v in shape_env.var_to_range.items()
112*523fa7a6SAndroid Build Coastguard Worker        if k not in shape_env.replacements
113*523fa7a6SAndroid Build Coastguard Worker    }
114*523fa7a6SAndroid Build Coastguard Worker    # Only when we have an unbacked symint, and it's used as constructor inputs,
115*523fa7a6SAndroid Build Coastguard Worker    # runtime_var_to_range will make a difference compated to var_to_range.
116*523fa7a6SAndroid Build Coastguard Worker    # e.g. [2, oo) -> [0, oo)
117*523fa7a6SAndroid Build Coastguard Worker    for k, v in shape_env.var_to_range.items():
118*523fa7a6SAndroid Build Coastguard Worker        if k not in shape_env.replacements:
119*523fa7a6SAndroid Build Coastguard Worker            range_constraints[k] = v
120*523fa7a6SAndroid Build Coastguard Worker    return range_constraints
121*523fa7a6SAndroid Build Coastguard Worker
122*523fa7a6SAndroid Build Coastguard Worker
123*523fa7a6SAndroid Build Coastguard Workerdef _get_updated_graph_signature(
124*523fa7a6SAndroid Build Coastguard Worker    old_signature: ExportGraphSignature,
125*523fa7a6SAndroid Build Coastguard Worker    new_gm: torch.fx.GraphModule,
126*523fa7a6SAndroid Build Coastguard Worker) -> ExportGraphSignature:
127*523fa7a6SAndroid Build Coastguard Worker    """
128*523fa7a6SAndroid Build Coastguard Worker    Update the graph signature's user_input/user_outputs.
129*523fa7a6SAndroid Build Coastguard Worker    """
130*523fa7a6SAndroid Build Coastguard Worker    new_input_specs = []
131*523fa7a6SAndroid Build Coastguard Worker    i = 0
132*523fa7a6SAndroid Build Coastguard Worker    for node in new_gm.graph.nodes:
133*523fa7a6SAndroid Build Coastguard Worker        if node.op != "placeholder":
134*523fa7a6SAndroid Build Coastguard Worker            continue
135*523fa7a6SAndroid Build Coastguard Worker
136*523fa7a6SAndroid Build Coastguard Worker        assert i < len(
137*523fa7a6SAndroid Build Coastguard Worker            old_signature.input_specs
138*523fa7a6SAndroid Build Coastguard Worker        ), "Number of inputs changed after transformation"
139*523fa7a6SAndroid Build Coastguard Worker        old_input_spec = old_signature.input_specs[i]
140*523fa7a6SAndroid Build Coastguard Worker        arg = (
141*523fa7a6SAndroid Build Coastguard Worker            old_input_spec.arg
142*523fa7a6SAndroid Build Coastguard Worker            if isinstance(old_input_spec.arg, ConstantArgument)
143*523fa7a6SAndroid Build Coastguard Worker            # pyre-fixme[20]: Argument `class_fqn` expected.
144*523fa7a6SAndroid Build Coastguard Worker            else type(old_input_spec.arg)(node.name)
145*523fa7a6SAndroid Build Coastguard Worker        )
146*523fa7a6SAndroid Build Coastguard Worker        new_input_specs.append(
147*523fa7a6SAndroid Build Coastguard Worker            InputSpec(
148*523fa7a6SAndroid Build Coastguard Worker                old_input_spec.kind,
149*523fa7a6SAndroid Build Coastguard Worker                arg,
150*523fa7a6SAndroid Build Coastguard Worker                old_input_spec.target,
151*523fa7a6SAndroid Build Coastguard Worker                persistent=old_input_spec.persistent,
152*523fa7a6SAndroid Build Coastguard Worker            )
153*523fa7a6SAndroid Build Coastguard Worker        )
154*523fa7a6SAndroid Build Coastguard Worker        i += 1
155*523fa7a6SAndroid Build Coastguard Worker
156*523fa7a6SAndroid Build Coastguard Worker    output_node = list(new_gm.graph.nodes)[-1]
157*523fa7a6SAndroid Build Coastguard Worker    assert output_node.op == "output"
158*523fa7a6SAndroid Build Coastguard Worker
159*523fa7a6SAndroid Build Coastguard Worker    new_output_specs = []
160*523fa7a6SAndroid Build Coastguard Worker    for i, node in enumerate(output_node.args[0]):
161*523fa7a6SAndroid Build Coastguard Worker        assert i < len(
162*523fa7a6SAndroid Build Coastguard Worker            old_signature.output_specs
163*523fa7a6SAndroid Build Coastguard Worker        ), "Number of outputs changed after transformation"
164*523fa7a6SAndroid Build Coastguard Worker        old_output_spec = old_signature.output_specs[i]
165*523fa7a6SAndroid Build Coastguard Worker        arg = (
166*523fa7a6SAndroid Build Coastguard Worker            old_output_spec.arg
167*523fa7a6SAndroid Build Coastguard Worker            if isinstance(old_output_spec.arg, ConstantArgument)
168*523fa7a6SAndroid Build Coastguard Worker            # pyre-fixme[20]: Argument `class_fqn` expected.
169*523fa7a6SAndroid Build Coastguard Worker            else type(old_output_spec.arg)(node.name)
170*523fa7a6SAndroid Build Coastguard Worker        )
171*523fa7a6SAndroid Build Coastguard Worker        new_output_specs.append(
172*523fa7a6SAndroid Build Coastguard Worker            OutputSpec(old_output_spec.kind, arg, old_output_spec.target)
173*523fa7a6SAndroid Build Coastguard Worker        )
174*523fa7a6SAndroid Build Coastguard Worker
175*523fa7a6SAndroid Build Coastguard Worker    new_signature = ExportGraphSignature(
176*523fa7a6SAndroid Build Coastguard Worker        input_specs=new_input_specs, output_specs=new_output_specs
177*523fa7a6SAndroid Build Coastguard Worker    )
178*523fa7a6SAndroid Build Coastguard Worker    return new_signature
179*523fa7a6SAndroid Build Coastguard Worker
180*523fa7a6SAndroid Build Coastguard Worker
def _transform(self, *passes: PassType) -> "ExportedProgram":
    """Run ``passes`` over ``self.graph_module`` and wrap the result.

    This is a free function that takes the ``ExportedProgram`` explicitly as
    ``self``.

    Args:
        self: The exported program to transform.
        *passes: Passes handed, in order, to a torch.fx ``PassManager``.

    Returns:
        ``self`` when the pass manager yields no result, or yields the same
        graph module and reports it unmodified. Otherwise, a new
        ``ExportedProgram`` around the transformed graph module, with:
        - a rebuilt graph signature;
        - recomputed range constraints;
        - a deep-copied module call graph;
        - ``graph_module.meta`` merged from the original and then the pass
          result.
    """
    pm = PassManager(list(passes))
    res = pm(self.graph_module)
    transformed_gm = res.graph_module if res is not None else self.graph_module
    assert transformed_gm is not None

    # `res` is allowed to be None (see above), so it must not be dereferenced
    # unconditionally: no result, or the same module reported as unmodified,
    # means there is nothing new to wrap.
    if transformed_gm is self.graph_module and (res is None or not res.modified):
        return self

    transformed_ep = ExportedProgram(
        root=transformed_gm,
        graph=transformed_gm.graph,
        graph_signature=_get_updated_graph_signature(
            self.graph_signature, transformed_gm
        ),
        state_dict=self.state_dict,
        range_constraints=_get_updated_range_constraints(transformed_gm),
        module_call_graph=copy.deepcopy(self._module_call_graph),
        example_inputs=self.example_inputs,
        constants=self.constants,
        verifiers=[self.verifier],
    )
    transformed_ep.graph_module.meta.update(self.graph_module.meta)
    # `res` is not None here: a None result leaves `transformed_gm` equal to
    # `self.graph_module` and returns early above. The pass result's metadata
    # is applied last, so it overrides the original on key collisions.
    transformed_ep.graph_module.meta.update(res.graph_module.meta)
    return transformed_ep
206*523fa7a6SAndroid Build Coastguard Worker
207*523fa7a6SAndroid Build Coastguard Worker
208*523fa7a6SAndroid Build Coastguard Workerdef _copy_module(new_prog, new_gm):
209*523fa7a6SAndroid Build Coastguard Worker    new_prog.meta.update(new_gm.meta)
210*523fa7a6SAndroid Build Coastguard Worker    new_prog.graph = new_gm.graph
211*523fa7a6SAndroid Build Coastguard Worker    submodules = [name for name, _ in new_prog.named_children()]
212*523fa7a6SAndroid Build Coastguard Worker    for name in submodules:
213*523fa7a6SAndroid Build Coastguard Worker        delattr(new_prog, name)
214*523fa7a6SAndroid Build Coastguard Worker    for name, mod in new_gm.named_children():
215*523fa7a6SAndroid Build Coastguard Worker        setattr(new_prog, name, mod)
216*523fa7a6SAndroid Build Coastguard Worker    for node in new_gm.graph.nodes:
217*523fa7a6SAndroid Build Coastguard Worker        if node.op == "get_attr":
218*523fa7a6SAndroid Build Coastguard Worker            t = getattr(new_gm, node.target, None)
219*523fa7a6SAndroid Build Coastguard Worker            if isinstance(t, torch.Tensor):
220*523fa7a6SAndroid Build Coastguard Worker                setattr(new_prog, node.target, t)
221*523fa7a6SAndroid Build Coastguard Worker
222*523fa7a6SAndroid Build Coastguard Worker
223*523fa7a6SAndroid Build Coastguard Workerdef lift_constant_tensor_pass(ep):
224*523fa7a6SAndroid Build Coastguard Worker    """
225*523fa7a6SAndroid Build Coastguard Worker    Takes an ExportedProgram and returns the ExportedProgram modified in-place,
226*523fa7a6SAndroid Build Coastguard Worker    with the constant tensors as buffers.
227*523fa7a6SAndroid Build Coastguard Worker    """
228*523fa7a6SAndroid Build Coastguard Worker    if len([node for node in ep.graph.nodes if node.op == "placeholder"]) == 0:
229*523fa7a6SAndroid Build Coastguard Worker        return ep
230*523fa7a6SAndroid Build Coastguard Worker
231*523fa7a6SAndroid Build Coastguard Worker    graph_signature = ep.graph_signature
232*523fa7a6SAndroid Build Coastguard Worker    buffers = list(graph_signature.buffers)
233*523fa7a6SAndroid Build Coastguard Worker
234*523fa7a6SAndroid Build Coastguard Worker    fake_mode = list(ep.graph.nodes)[0].meta["val"].fake_mode
235*523fa7a6SAndroid Build Coastguard Worker    first_user_input = None
236*523fa7a6SAndroid Build Coastguard Worker    lifted_constants = []
237*523fa7a6SAndroid Build Coastguard Worker    for node in ep.graph.nodes:
238*523fa7a6SAndroid Build Coastguard Worker        if node.op == "placeholder" and node.name in graph_signature.user_inputs:
239*523fa7a6SAndroid Build Coastguard Worker            first_user_input = node
240*523fa7a6SAndroid Build Coastguard Worker            break
241*523fa7a6SAndroid Build Coastguard Worker
242*523fa7a6SAndroid Build Coastguard Worker    for node in ep.graph.nodes:
243*523fa7a6SAndroid Build Coastguard Worker        if node.op == "get_attr":
244*523fa7a6SAndroid Build Coastguard Worker            constant_tensor = getattr(ep.graph_module, node.target)
245*523fa7a6SAndroid Build Coastguard Worker            if not isinstance(constant_tensor, torch.Tensor):
246*523fa7a6SAndroid Build Coastguard Worker                continue
247*523fa7a6SAndroid Build Coastguard Worker
248*523fa7a6SAndroid Build Coastguard Worker            constant_tensor_fqn = f"_lifted_tensor_constant{len(buffers)}"
249*523fa7a6SAndroid Build Coastguard Worker
250*523fa7a6SAndroid Build Coastguard Worker            with ep.graph.inserting_before(first_user_input):
251*523fa7a6SAndroid Build Coastguard Worker                # Insert the constant node before the first user input
252*523fa7a6SAndroid Build Coastguard Worker                const_placeholder_node = ep.graph.placeholder(constant_tensor_fqn)
253*523fa7a6SAndroid Build Coastguard Worker                for k, v in node.meta.items():
254*523fa7a6SAndroid Build Coastguard Worker                    const_placeholder_node.meta[k] = v
255*523fa7a6SAndroid Build Coastguard Worker                if fake_mode is not None:
256*523fa7a6SAndroid Build Coastguard Worker                    const_placeholder_node.meta["val"] = fake_mode.from_tensor(
257*523fa7a6SAndroid Build Coastguard Worker                        constant_tensor, static_shapes=True
258*523fa7a6SAndroid Build Coastguard Worker                    )
259*523fa7a6SAndroid Build Coastguard Worker                else:
260*523fa7a6SAndroid Build Coastguard Worker                    const_placeholder_node.meta["val"] = constant_tensor
261*523fa7a6SAndroid Build Coastguard Worker                const_placeholder_node.meta["val"].constant = constant_tensor
262*523fa7a6SAndroid Build Coastguard Worker                node.replace_all_uses_with(const_placeholder_node)
263*523fa7a6SAndroid Build Coastguard Worker                ep.graph.erase_node(node)
264*523fa7a6SAndroid Build Coastguard Worker
265*523fa7a6SAndroid Build Coastguard Worker                # Add the constant as a buffer to the graph signature
266*523fa7a6SAndroid Build Coastguard Worker                lifted_constants.append(
267*523fa7a6SAndroid Build Coastguard Worker                    InputSpec(
268*523fa7a6SAndroid Build Coastguard Worker                        kind=InputKind.BUFFER,
269*523fa7a6SAndroid Build Coastguard Worker                        arg=TensorArgument(name=const_placeholder_node.name),
270*523fa7a6SAndroid Build Coastguard Worker                        target=constant_tensor_fqn,
271*523fa7a6SAndroid Build Coastguard Worker                        persistent=True,
272*523fa7a6SAndroid Build Coastguard Worker                    )
273*523fa7a6SAndroid Build Coastguard Worker                )
274*523fa7a6SAndroid Build Coastguard Worker                buffers.append(constant_tensor_fqn)
275*523fa7a6SAndroid Build Coastguard Worker                ep.state_dict[constant_tensor_fqn] = constant_tensor
276*523fa7a6SAndroid Build Coastguard Worker
277*523fa7a6SAndroid Build Coastguard Worker    new_input_specs = []
278*523fa7a6SAndroid Build Coastguard Worker    for s in graph_signature.input_specs:
279*523fa7a6SAndroid Build Coastguard Worker        if s.kind == InputKind.USER_INPUT and len(lifted_constants) > 0:
280*523fa7a6SAndroid Build Coastguard Worker            new_input_specs.extend(lifted_constants)
281*523fa7a6SAndroid Build Coastguard Worker            lifted_constants.clear()
282*523fa7a6SAndroid Build Coastguard Worker        new_input_specs.append(s)
283*523fa7a6SAndroid Build Coastguard Worker    ep.graph_signature.input_specs = new_input_specs
284*523fa7a6SAndroid Build Coastguard Worker    ep.graph_module.recompile()
285*523fa7a6SAndroid Build Coastguard Worker    return ep
286*523fa7a6SAndroid Build Coastguard Worker
287*523fa7a6SAndroid Build Coastguard Worker
# Stub to ease migration from `transform` to private `_transform`
def transform_exported_program(ep, *passes: PassType) -> ExportedProgram:
    """Apply ``passes`` to ``ep`` through whichever transform method it offers.

    Calls ``ep._transform`` when ``ep`` has that attribute. Otherwise it falls
    back to the public ``ep.transform``. Returns whatever that call returns.
    """
    method_name = "_transform" if hasattr(ep, "_transform") else "transform"
    return getattr(ep, method_name)(*passes)
294*523fa7a6SAndroid Build Coastguard Worker
295*523fa7a6SAndroid Build Coastguard Worker
296*523fa7a6SAndroid Build Coastguard Workerclass HackedUpExportedProgramDONOTUSE(ExportedProgram):
297*523fa7a6SAndroid Build Coastguard Worker    def __init__(
298*523fa7a6SAndroid Build Coastguard Worker        self,
299*523fa7a6SAndroid Build Coastguard Worker        root,
300*523fa7a6SAndroid Build Coastguard Worker        graph,
301*523fa7a6SAndroid Build Coastguard Worker        graph_signature,
302*523fa7a6SAndroid Build Coastguard Worker        call_spec,
303*523fa7a6SAndroid Build Coastguard Worker        state_dict,
304*523fa7a6SAndroid Build Coastguard Worker        range_constraints,
305*523fa7a6SAndroid Build Coastguard Worker        module_call_graph,
306*523fa7a6SAndroid Build Coastguard Worker        example_inputs,
307*523fa7a6SAndroid Build Coastguard Worker        verifier,
308*523fa7a6SAndroid Build Coastguard Worker    ):
309*523fa7a6SAndroid Build Coastguard Worker        super().__init__(
310*523fa7a6SAndroid Build Coastguard Worker            root=root,
311*523fa7a6SAndroid Build Coastguard Worker            graph=graph,
312*523fa7a6SAndroid Build Coastguard Worker            graph_signature=graph_signature,
313*523fa7a6SAndroid Build Coastguard Worker            state_dict=state_dict,
314*523fa7a6SAndroid Build Coastguard Worker            range_constraints=range_constraints,
315*523fa7a6SAndroid Build Coastguard Worker            module_call_graph=module_call_graph,
316*523fa7a6SAndroid Build Coastguard Worker            example_inputs=example_inputs,
317*523fa7a6SAndroid Build Coastguard Worker            verifier=verifier,
318*523fa7a6SAndroid Build Coastguard Worker        )
319*523fa7a6SAndroid Build Coastguard Worker
320*523fa7a6SAndroid Build Coastguard Worker    def __call__(self, *args: Any, **kwargs: Any) -> Any:
321*523fa7a6SAndroid Build Coastguard Worker        import torch._export.error as error
322*523fa7a6SAndroid Build Coastguard Worker
323*523fa7a6SAndroid Build Coastguard Worker        if self.call_spec.in_spec is not None:
324*523fa7a6SAndroid Build Coastguard Worker            user_args = args
325*523fa7a6SAndroid Build Coastguard Worker            try:
326*523fa7a6SAndroid Build Coastguard Worker                args = fx_pytree.tree_flatten_spec(user_args, self.call_spec.in_spec)  # type: ignore[assignment]
327*523fa7a6SAndroid Build Coastguard Worker            except Exception:
328*523fa7a6SAndroid Build Coastguard Worker                _, received_spec = pytree.tree_flatten(user_args)
329*523fa7a6SAndroid Build Coastguard Worker                raise error.InternalError(
330*523fa7a6SAndroid Build Coastguard Worker                    "Trying to flatten user inputs with exported input tree spec: \n"
331*523fa7a6SAndroid Build Coastguard Worker                    f"{self.call_spec.in_spec}\n"
332*523fa7a6SAndroid Build Coastguard Worker                    "but actually got inputs with tree spec of: \n"
333*523fa7a6SAndroid Build Coastguard Worker                    f"{received_spec}"
334*523fa7a6SAndroid Build Coastguard Worker                )
335*523fa7a6SAndroid Build Coastguard Worker
336*523fa7a6SAndroid Build Coastguard Worker        ordered_params = tuple(
337*523fa7a6SAndroid Build Coastguard Worker            self.state_dict[name] for name in self.graph_signature.parameters
338*523fa7a6SAndroid Build Coastguard Worker        )
339*523fa7a6SAndroid Build Coastguard Worker        ordered_buffers = tuple(
340*523fa7a6SAndroid Build Coastguard Worker            self.state_dict[name] for name in self.graph_signature.buffers
341*523fa7a6SAndroid Build Coastguard Worker        )
342*523fa7a6SAndroid Build Coastguard Worker
343*523fa7a6SAndroid Build Coastguard Worker        with torch.no_grad():
344*523fa7a6SAndroid Build Coastguard Worker            # NOTE: calling convention is first params, then buffers, then args as user supplied them.
345*523fa7a6SAndroid Build Coastguard Worker            # See: torch/_functorch/aot_autograd.py#L1034
346*523fa7a6SAndroid Build Coastguard Worker            res = torch.fx.Interpreter(self.graph_module).run(
347*523fa7a6SAndroid Build Coastguard Worker                *ordered_params, *ordered_buffers, *args, enable_io_processing=False
348*523fa7a6SAndroid Build Coastguard Worker            )
349*523fa7a6SAndroid Build Coastguard Worker
350*523fa7a6SAndroid Build Coastguard Worker        if self.call_spec.out_spec is not None:
351*523fa7a6SAndroid Build Coastguard Worker            mutation = self.graph_signature.buffers_to_mutate
352*523fa7a6SAndroid Build Coastguard Worker            num_mutated = len(mutation)
353*523fa7a6SAndroid Build Coastguard Worker            mutated_buffers = res[:num_mutated]
354*523fa7a6SAndroid Build Coastguard Worker
355*523fa7a6SAndroid Build Coastguard Worker            # Exclude dependency token from final result.
356*523fa7a6SAndroid Build Coastguard Worker            assertion_dep_token = self.graph_signature.assertion_dep_token
357*523fa7a6SAndroid Build Coastguard Worker            if assertion_dep_token is not None:
358*523fa7a6SAndroid Build Coastguard Worker                assertion_dep_token_index = list(assertion_dep_token.keys())[0]
359*523fa7a6SAndroid Build Coastguard Worker                res = res[:assertion_dep_token_index]
360*523fa7a6SAndroid Build Coastguard Worker
361*523fa7a6SAndroid Build Coastguard Worker            res = res[num_mutated:]
362*523fa7a6SAndroid Build Coastguard Worker            try:
363*523fa7a6SAndroid Build Coastguard Worker                res = pytree.tree_unflatten(res, self.call_spec.out_spec)
364*523fa7a6SAndroid Build Coastguard Worker            except Exception:
365*523fa7a6SAndroid Build Coastguard Worker                _, received_spec = pytree.tree_flatten(res)
366*523fa7a6SAndroid Build Coastguard Worker                raise error.InternalError(
367*523fa7a6SAndroid Build Coastguard Worker                    "Trying to flatten user outputs with exported output tree spec: \n"
368*523fa7a6SAndroid Build Coastguard Worker                    f"{self.call_spec.out_spec}\n"
369*523fa7a6SAndroid Build Coastguard Worker                    "but actually got outputs with tree spec of: \n"
370*523fa7a6SAndroid Build Coastguard Worker                    f"{received_spec}"
371*523fa7a6SAndroid Build Coastguard Worker                )
372*523fa7a6SAndroid Build Coastguard Worker            finally:
373*523fa7a6SAndroid Build Coastguard Worker                ix = 0
374*523fa7a6SAndroid Build Coastguard Worker                for buffer in self.graph_signature.buffers_to_mutate.values():
375*523fa7a6SAndroid Build Coastguard Worker                    self.state_dict[buffer] = mutated_buffers[ix]
376*523fa7a6SAndroid Build Coastguard Worker                    ix += 1
377*523fa7a6SAndroid Build Coastguard Worker        return res
378*523fa7a6SAndroid Build Coastguard Worker
379*523fa7a6SAndroid Build Coastguard Worker
@compatibility(is_backward_compatible=False)
class ExirExportedProgram:
    """Legacy wrapper around an ``ExportedProgram`` used by the old capture flow.

    Alongside the wrapped program it stores ``after_to_edge_passes``, a flag
    recording whether this program came out of ``to_edge``; ``to_executorch``
    (and ``ExecutorchProgram``) refuse to proceed unless it is set.
    """

    def __init__(
        self,
        exported_program: ExportedProgram,
        after_to_edge_passes: bool,
    ):
        self.exported_program = exported_program
        # Flag denoting whether to_edge has been run on this program; used to
        # detect misuse such as calling to_executorch without to_edge first.
        self.after_to_edge_passes = after_to_edge_passes

    def transform(self, *passes: PassType) -> "ExirExportedProgram":
        """Run ``passes`` over the wrapped program via ``_transform``.

        The wrapper is updated in place and ``self`` is returned for chaining.
        """
        updated_program = _transform(self.exported_program, *passes)
        self.exported_program = updated_program
        return self

    def __call__(self, *args: Any) -> Any:
        """Execute the wrapped program's module on the positional ``args``."""
        runnable = self.exported_program.module()
        return runnable(*args)

    # TODO(ycao): Change this to a composable function.
    def to_edge(
        self, config: Optional[EdgeCompileConfig] = None
    ) -> "ExirExportedProgram":
        """Lower the wrapped program to the Edge dialect by delegating to ``_to_edge``.

        Args:
            config: edge compile options; a falsy/missing value is replaced by
                a default ``EdgeCompileConfig()``.

        Raises:
            AssertionError: if the wrapped program's ``graph_module`` is not a
                ``torch.fx.GraphModule``.
        """
        if not config:
            config = EdgeCompileConfig()
        graph_module = self.exported_program.graph_module
        assert isinstance(
            graph_module, torch.fx.GraphModule
        ), f"type is instead: {type(graph_module).__name__}"

        return _to_edge(self, config)

    def dump(self) -> None:
        """Print the FX graph of the wrapped program to stdout."""
        graph = self.exported_program.graph_module.graph
        print(graph)

    def to_executorch(
        self,
        config: Optional[ExecutorchBackendConfig] = None,
    ) -> "ExecutorchProgram":
        """Run the Edge -> ExecuTorch passes and wrap the result in an ``ExecutorchProgram``.

        The passes from ``edge_to_executorch_passes(config)`` run first, then
        ``config.memory_planning_pass``. Each receives the previous pass's
        graph module. The first pass receives this program's own graph module,
        not a copy; whether it is mutated in place depends on the passes.
        The final graph module is copied into a deep copy of the wrapped
        program, which is handed to ``ExecutorchProgram``.

        Args:
            config: backend options; a falsy/missing value is replaced by a
                default ``ExecutorchBackendConfig()``.

        Raises:
            RuntimeError: if ``to_edge`` has not produced this program.
        """
        if not self.after_to_edge_passes:
            raise RuntimeError("Must run to_edge before to_executorch.")
        if not config:
            config = ExecutorchBackendConfig()

        graph_module = self.exported_program.graph_module
        for lowering_pass in edge_to_executorch_passes(config):
            pass_result = lowering_pass(graph_module)
            assert pass_result is not None
            graph_module = pass_result.graph_module

        # Tech debt: the memory planning pass inherits from the graph-module pass
        # infra, which lacks the information needed for mutable buffers (that needs
        # a separate `run` entry point). Existing user passes do not use `run` and
        # do not handle mutable buffers yet, so the plain call is used here. This
        # should be cleaned up once exir.capture is removed.
        planning_result = config.memory_planning_pass(graph_module)  # pyre-ignore[29]
        assert planning_result is not None
        graph_module = planning_result.graph_module

        lowered = ExirExportedProgram(
            copy.deepcopy(self.exported_program), self.after_to_edge_passes
        )
        _copy_module(lowered.exported_program.graph_module, graph_module)
        executorch_prog = ExecutorchProgram(
            lowered,
            emit_stacktrace=config.emit_stacktrace,
            extract_delegate_segments=config.extract_delegate_segments,
            segment_alignment=config.segment_alignment,
            constant_tensor_alignment=config.constant_tensor_alignment,
            delegate_alignment=config.delegate_alignment,
        )
        # Merge metadata: the lowered graph's meta first, then this program's
        # graph-module meta, which wins on conflicting keys.
        merged_meta = executorch_prog.graph_module.meta
        merged_meta.update(graph_module.meta)
        merged_meta.update(self.exported_program.graph_module.meta)
        return executorch_prog

    def __deepcopy__(
        self, memo: Optional[Dict[int, Any]] = None
    ) -> "ExirExportedProgram":
        """Deep-copy the wrapped program (forwarding ``memo``) and keep the to_edge flag."""
        copied_program = copy.deepcopy(self.exported_program, memo)
        return ExirExportedProgram(copied_program, self.after_to_edge_passes)
461*523fa7a6SAndroid Build Coastguard Worker
462*523fa7a6SAndroid Build Coastguard Worker
@compatibility(is_backward_compatible=False)
class ExecutorchProgram:
    """Emitted and serialized view of a lowered ``ExirExportedProgram``.

    Emission (``program``) and PTE serialization (``_get_pte_data`` / ``buffer``)
    run lazily on first access, and their results are cached on the instance.
    """

    def __init__(
        self,
        exir_exported_program: ExirExportedProgram,
        emit_stacktrace: bool,
        extract_delegate_segments: bool,
        segment_alignment: int,
        constant_tensor_alignment: Optional[int] = None,
        delegate_alignment: Optional[int] = None,
    ) -> None:
        if not exir_exported_program.after_to_edge_passes:
            raise RuntimeError(
                "Need to call prog.to_edge prior to constructing ExecutorchProgram."
            )
        self.exported_program = exir_exported_program.exported_program

        # Options forwarded to emit_program / _serialize_pte_binary.
        self._emit_stacktrace: bool = emit_stacktrace
        self._extract_delegate_segments: bool = extract_delegate_segments
        self._segment_alignment: int = segment_alignment
        self._constant_tensor_alignment: Optional[int] = constant_tensor_alignment
        self._delegate_alignment: Optional[int] = delegate_alignment

        # Lazily populated caches (see program, _get_pte_data and buffer).
        self._emitter_output: Optional[EmitterOutput] = None
        self._pte_data: Optional[Cord] = None
        self._buffer: Optional[bytes] = None

    def _get_pte_data(self) -> Cord:
        """Serialize ``self.program`` into a PTE ``Cord``, computing it at most once."""
        if self._pte_data is not None:
            return self._pte_data
        self._pte_data = _serialize_pte_binary(
            program=self.program,
            extract_delegate_segments=self._extract_delegate_segments,
            segment_alignment=self._segment_alignment,
            constant_tensor_alignment=self._constant_tensor_alignment,
            delegate_alignment=self._delegate_alignment,
        )
        return self._pte_data

    @property
    def buffer(self) -> bytes:
        """Return the serialized ExecuTorch binary as a single ``bytes`` object.

        Materializing this can allocate a very large contiguous block for big
        models, and the result stays cached on the instance. When the goal is
        to write a file, use `write_to_file`, which avoids that extra copy.
        """
        # TODO(T181494963): update pybinding to remove buffer cache, which can consume large
        # amounts of memory longer than necessary.
        if self._buffer is None:
            pte_data = self._get_pte_data()
            self._buffer = bytes(pte_data)
        return self._buffer

    @property
    def program(self) -> Program:
        """Return the emitted ``Program``, running ``emit_program`` on first access."""
        emitted = self._emitter_output
        if emitted is None:
            emitted = emit_program(self.exported_program, self._emit_stacktrace)
            self._emitter_output = emitted
        return emitted.program

    @property
    def debug_handle_map(self) -> Dict[int, Union[int, List[int]]]:
        """Debug-handle map produced by emission, or ``{}`` if not yet emitted.

        Accessing this does not trigger emission; read ``program`` (or
        ``buffer``) first to populate it.
        """
        return self._emitter_output.debug_handle_map if self._emitter_output else {}

    @property
    def delegate_map(
        self,
    ) -> Dict[str, Dict[int, Dict[str, Union[str, _DelegateDebugIdentifierMap]]]]:
        """Per-method delegate debug-id map from emission, or ``{}`` if not yet emitted.

        Like ``debug_handle_map``, this does not trigger emission.
        """
        return (
            self._emitter_output.method_to_delegate_debug_id_map
            if self._emitter_output
            else {}
        )

    @property
    def graph_module(self) -> torch.fx.GraphModule:
        """Graph module of the wrapped exported program."""
        return self.exported_program.graph_module

    # TODO (zhxchen17) Change this to property.
    def dump_graph_module(self) -> torch.fx.GraphModule:
        """Return the wrapped exported program's graph module."""
        return self.exported_program.graph_module

    def dump_exported_program(self) -> ExportedProgram:
        """Return the wrapped exported program."""
        return self.exported_program

    def write_to_file(self, open_file: io.BufferedIOBase) -> None:
        """Write the serialized ExecuTorch binary to ``open_file``.

        Prefer this over ``buffer``: it streams the cached ``Cord`` to the
        file without first joining it into one contiguous block of memory,
        which lowers peak memory usage.
        """
        pte_data = self._get_pte_data()
        pte_data.write_to_file(open_file)
553*523fa7a6SAndroid Build Coastguard Worker
554*523fa7a6SAndroid Build Coastguard Worker
def _get_aten_to_edge_passes(config: EdgeCompileConfig):
    """Return ``(pre_op_replace_passes, post_op_replace_passes)`` for ATen -> Edge.

    The pre-op-replace list is a new list: ``base_pre_op_replace_passes``
    followed by a ``RemoveMixedTypeOperators()`` pass unless
    ``config._skip_type_promotion`` is set. The post-op-replace list is
    ``base_post_op_replace_passes`` itself (the same object, not a copy).
    """
    # TODO: the last two passes for aten_to_edge need to be eliminated_dead_code -> debug_handle_generator.
    # Once use_edge_op is enabled, they can move into aten_to_edge_passes before the eliminated_dead_code pass.
    # ExportPass also doesn't play well with node.meta: after passes that permute operators, some node.meta
    # information may be lost. SpecPropPass may regenerate it so the loss is not visible, but debug handles are lost.
    type_promotion_passes = []
    if not config._skip_type_promotion:
        type_promotion_passes.append(RemoveMixedTypeOperators())

    return (
        base_pre_op_replace_passes + type_promotion_passes,
        base_post_op_replace_passes,
    )
568*523fa7a6SAndroid Build Coastguard Worker
569*523fa7a6SAndroid Build Coastguard Worker
def _to_edge(ep, config: EdgeCompileConfig) -> "ExirExportedProgram":
    """Lower a legacy ``ExirExportedProgram`` from the ATen to the Edge dialect.

    Pipeline:
      1. If ``config._check_ir_validity`` is set, verify the input against the
         ATen dialect.
      2. If the program's dialect is ``"ATEN"``, re-wrap it with an ATen
         verifier built from ``config``.
      3. Run the pre-op-replace passes on a deep copy of ``ep``, then lift
         constant tensors (ATen input only).
      4. When ``config._use_edge_ops`` is set, run ``OpReplacePass`` and,
         unless ``config._skip_dim_order``, ``MemoryFormatOpsPass``.
      5. Run the post-op-replace passes and rebuild the ``ExportedProgram``
         with an Edge-dialect verifier.

    Args:
        ep: the ``ExirExportedProgram`` to lower.
        config: edge compile options controlling verification and passes.

    Returns:
        A new ``ExirExportedProgram`` holding the Edge-dialect program, with
        ``after_to_edge_passes`` set to True.

    Raises:
        ExportError: re-raised from the ATen-dialect verifier when
            ``config._check_ir_validity`` is set and the input fails the check.
    """

    def _run_pass(graph_pass, graph_module):
        # Run one graph-module pass and unwrap the resulting graph module.
        result = graph_pass(graph_module)
        assert result is not None
        return result.graph_module

    if config._check_ir_validity:
        try:
            # Pass the config so the verifier can see
            # config._core_aten_ops_exception_list. The log message below
            # tells users to add failing ops to that list, which has no effect
            # unless the verifier actually receives the config.
            EXIRATenDialectVerifier(edge_compile_config=config)(
                ep.exported_program.graph_module
            )
        except ExportError:
            logging.info(
                "If a particular operator failed core ATen IR check, please consider adding it to the exception list. "
                "Add the operator to _core_aten_ops_exception_list in EdgeCompileConfig. This is the recommended way "
                "to resolve this type of failure, so that the rest of the IR validation check can still be performed.\n"
                "If you'd like to disable IR validation checking, please set _check_ir_validity in EdgeCompileConfig, "
                "like *.to_edge(exir.EdgeCompileConfig(_check_ir_validity=False))."
            )
            raise

    dialect = ep.exported_program.dialect
    if dialect == "ATEN":
        ep = ExirExportedProgram(
            ExportedProgram(
                root=ep.exported_program.graph_module,
                graph=ep.exported_program.graph_module.graph,
                graph_signature=ep.exported_program.graph_signature,
                state_dict=ep.exported_program.state_dict,
                range_constraints=ep.exported_program.range_constraints,
                module_call_graph=ep.exported_program.module_call_graph,
                example_inputs=ep.exported_program.example_inputs,
                constants=ep.exported_program.constants,
                verifiers=[
                    get_aten_verifier(
                        config=config,
                    )
                ],
            ),
            False,
        )
    pre_op_replace_passes, post_op_replace_passes = _get_aten_to_edge_passes(config)

    # The pass pipeline operates on a deep copy of ``ep``.
    new_ep = copy.deepcopy(ep).transform(*pre_op_replace_passes)
    if dialect == "ATEN":
        new_ep.exported_program = lift_constant_tensor_pass(new_ep.exported_program)

    new_gm = new_ep.exported_program.graph_module
    if config._use_edge_ops:
        new_gm = _run_pass(OpReplacePass(), new_gm)
        if not config._skip_dim_order:
            new_gm = _run_pass(MemoryFormatOpsPass(), new_gm)

    for p in post_op_replace_passes:
        new_gm = _run_pass(p, new_gm)

    new_ep.exported_program = ExportedProgram(
        root=new_gm,
        graph=new_gm.graph,
        graph_signature=_get_updated_graph_signature(
            new_ep.exported_program.graph_signature, new_gm
        ),
        state_dict=new_ep.exported_program.state_dict,
        range_constraints=new_ep.exported_program.range_constraints,
        module_call_graph=new_ep.exported_program.module_call_graph,
        example_inputs=new_ep.exported_program.example_inputs,
        constants=new_ep.exported_program.constants,
        verifiers=[
            EXIREdgeDialectVerifier(
                edge_compile_config=config,
                class_only=True,
            )
        ],
    )
    new_ep.after_to_edge_passes = True
    return new_ep
645*523fa7a6SAndroid Build Coastguard Worker
646*523fa7a6SAndroid Build Coastguard Worker
def pre_memory_planning_passes(
    config: ExecutorchBackendConfig, name: Optional[str] = None
) -> List[PassType]:
    """
    Returns a list of passes to run before memory planning.

    The symbolic-shape-eval pass comes from ``config.sym_shape_eval_pass``:
      - if it is a dict keyed by method name, the entry for ``name`` is used;
        the default ``ExecutorchBackendConfig().sym_shape_eval_pass`` is used
        when ``name`` is empty/None or has no entry;
      - if it is a single ``PassBase``, it is used for every method.

    Args:
        config: backend config providing the sym-shape-eval pass, the
            to-out-var pass and the ``remove_view_copy`` switch.
        name: method name used to look up a per-method sym-shape-eval pass.

    Returns:
        If ``config.remove_view_copy`` is set: ``NormalizeViewCopyBasePass()``,
        ``dead_code_elimination_pass``, ``ReplaceViewCopyWithViewPass()``,
        then the sym-shape-eval pass and ``config.to_out_var_pass``.
        Otherwise only the last two.

    Raises:
        RuntimeError: if ``config.sym_shape_eval_pass`` is neither a dict nor
            a ``PassBase``.
    """
    configured = config.sym_shape_eval_pass
    if isinstance(configured, dict):
        default_pass = ExecutorchBackendConfig().sym_shape_eval_pass
        if not name:
            # No method name to look up: fall back to the default pass. (This
            # branch used to be a dead store, overwritten by the lookup below.)
            sym_shape_eval_pass = default_pass
        else:
            sym_shape_eval_pass = configured.get(name, default_pass)
    elif isinstance(configured, PassBase):
        sym_shape_eval_pass = configured
    else:
        raise RuntimeError(
            f"sym_shape_eval_pass must be a dict or a PassBase, got {configured}"
        )

    passes: List[PassType] = []
    if config.remove_view_copy:
        passes.extend(
            [
                NormalizeViewCopyBasePass(),
                dead_code_elimination_pass,
                ReplaceViewCopyWithViewPass(),
            ]
        )
    passes.extend([sym_shape_eval_pass, config.to_out_var_pass])
    return passes
680*523fa7a6SAndroid Build Coastguard Worker
681*523fa7a6SAndroid Build Coastguard Worker
def edge_to_executorch_passes(
    config: ExecutorchBackendConfig, name: Optional[str] = None
) -> List[PassType]:
    """
    Returns the list of passes that lower an Edge-dialect graph toward ExecuTorch.

    Order: the user passes in ``config.passes``, then ``SpecPropPass``,
    ``EdgeToBackendOpsPass`` and ``RemoveGraphAssertsPass``, followed by the
    pre-memory-planning passes selected for method ``name`` (see
    ``pre_memory_planning_passes``). The memory planning pass itself is not
    part of this list.
    """
    passes: List[PassType] = list(config.passes)
    passes.append(SpecPropPass())
    # ExecuTorch backend ops cannot handle unbacked symints. After this pass,
    # later passes must not be Interpreter-based, because they would fail on
    # any unbacked symint operation.
    passes.append(EdgeToBackendOpsPass())
    passes.append(RemoveGraphAssertsPass())
    passes.extend(pre_memory_planning_passes(config, name))
    return passes
700*523fa7a6SAndroid Build Coastguard Worker
701*523fa7a6SAndroid Build Coastguard Worker
def _generate_edge_program(
    name: str,
    config: EdgeCompileConfig,
    program: ExportedProgram,
    ops_set_to_not_decompose: Optional[List[torch._ops.OpOverload]] = None,
) -> ExportedProgram:
    """Convert the ATen-dialect program for method ``name`` into an Edge-dialect program.

    Steps:
      1. Optionally verify the input against the ATen dialect, tolerating ops
         in ``ops_set_to_not_decompose``.
      2. Run ``ReplaceViewOpsWithViewCopyOpsPass``, the pre-op-replace passes
         and, when ``config._use_edge_ops`` is set, ``OpReplacePass`` (plus
         ``MemoryFormatOpsPass`` unless ``config._skip_dim_order``).
      3. Rebuild the ``ExportedProgram`` with an Edge-dialect verifier.
      4. Lift constant tensors and run the post-op-replace passes.

    Raises:
        ExportError: re-raised from the ATen verifier when
            ``config._check_ir_validity`` is set and the check fails.
    """
    if config._check_ir_validity:
        try:
            aten_verifier = EXIRATenDialectVerifier(
                edge_compile_config=config,
                class_only=False,
                exception_list=ops_set_to_not_decompose,
            )
            aten_verifier(program.graph_module)
        except ExportError as e:
            logging.info(f"Input program {name} is not in ATen dialect.")
            raise e

    pre_op_replace_passes, post_op_replace_passes = _get_aten_to_edge_passes(config)

    # TODO move ReplaceViewOpsWithViewCopyOpsPass inside aten_to_edge passes after
    # all users are migrated off v1 capture.
    passes = [ReplaceViewOpsWithViewCopyOpsPass(), *pre_op_replace_passes]
    if config._use_edge_ops:
        passes.append(OpReplacePass())
        if not config._skip_dim_order:
            passes.append(MemoryFormatOpsPass())

    graph_module = program.graph_module
    for graph_pass in passes:
        pass_result = graph_pass(graph_module)
        assert pass_result is not None
        graph_module = pass_result.graph_module

    edge_program = ExportedProgram(
        root=graph_module,
        graph=graph_module.graph,
        graph_signature=_get_updated_graph_signature(
            program.graph_signature, graph_module
        ),
        state_dict=program.state_dict,
        range_constraints=program.range_constraints,
        module_call_graph=program.module_call_graph,
        example_inputs=program.example_inputs,
        constants=program.constants,
        verifiers=[
            EXIREdgeDialectVerifier(
                edge_compile_config=config,
                class_only=True,
                exception_list=ops_set_to_not_decompose,
            )
        ],
    )
    # Lift the tensor constants created in ScalarToTensorPass.
    lifted_program = lift_constant_tensor_pass(edge_program)
    return _transform(lifted_program, *post_op_replace_passes)
759*523fa7a6SAndroid Build Coastguard Worker
760*523fa7a6SAndroid Build Coastguard Worker
761*523fa7a6SAndroid Build Coastguard Workerdef _replace_aten_ops_with_transformed_ops(
762*523fa7a6SAndroid Build Coastguard Worker    name: str,
763*523fa7a6SAndroid Build Coastguard Worker    program: ExportedProgram,
764*523fa7a6SAndroid Build Coastguard Worker    partitioner,
765*523fa7a6SAndroid Build Coastguard Worker):
766*523fa7a6SAndroid Build Coastguard Worker    ops_to_not_decompose = set()
767*523fa7a6SAndroid Build Coastguard Worker    partitioners = partitioner.get(name)
768*523fa7a6SAndroid Build Coastguard Worker    if partitioners is None:
769*523fa7a6SAndroid Build Coastguard Worker        return
770*523fa7a6SAndroid Build Coastguard Worker
771*523fa7a6SAndroid Build Coastguard Worker    # Iterate through the graph and replace the aten ops with the corresponding
772*523fa7a6SAndroid Build Coastguard Worker    # transformed ops.
773*523fa7a6SAndroid Build Coastguard Worker    for partitioner in partitioners:
774*523fa7a6SAndroid Build Coastguard Worker        ops_set_to_not_decompose, check_op_support = partitioner.ops_to_not_decompose(
775*523fa7a6SAndroid Build Coastguard Worker            program
776*523fa7a6SAndroid Build Coastguard Worker        )
777*523fa7a6SAndroid Build Coastguard Worker
778*523fa7a6SAndroid Build Coastguard Worker        for op_aten in ops_set_to_not_decompose:
779*523fa7a6SAndroid Build Coastguard Worker            _register_no_decomp_op(op_aten)
780*523fa7a6SAndroid Build Coastguard Worker
781*523fa7a6SAndroid Build Coastguard Worker        for node in program.graph.nodes:
782*523fa7a6SAndroid Build Coastguard Worker            is_op_supported = check_op_support(node) if check_op_support else True
783*523fa7a6SAndroid Build Coastguard Worker            if (
784*523fa7a6SAndroid Build Coastguard Worker                node.op == "call_function"
785*523fa7a6SAndroid Build Coastguard Worker                and node.target in ops_set_to_not_decompose
786*523fa7a6SAndroid Build Coastguard Worker                and is_op_supported
787*523fa7a6SAndroid Build Coastguard Worker            ):
788*523fa7a6SAndroid Build Coastguard Worker                ops_to_not_decompose.add(node.target)
789*523fa7a6SAndroid Build Coastguard Worker                node.target = aten_op_to_transform_op[node.target]
790*523fa7a6SAndroid Build Coastguard Worker
791*523fa7a6SAndroid Build Coastguard Worker        for _, submod, _ in get_control_flow_submodules(program.graph_module):
792*523fa7a6SAndroid Build Coastguard Worker            for node in submod.graph.nodes:
793*523fa7a6SAndroid Build Coastguard Worker                is_op_supported = check_op_support(node) if check_op_support else True
794*523fa7a6SAndroid Build Coastguard Worker                if (
795*523fa7a6SAndroid Build Coastguard Worker                    node.op == "call_function"
796*523fa7a6SAndroid Build Coastguard Worker                    and node.target in ops_set_to_not_decompose
797*523fa7a6SAndroid Build Coastguard Worker                    and is_op_supported
798*523fa7a6SAndroid Build Coastguard Worker                ):
799*523fa7a6SAndroid Build Coastguard Worker                    ops_to_not_decompose.add(node.target)
800*523fa7a6SAndroid Build Coastguard Worker                    node.target = aten_op_to_transform_op[node.target]
801*523fa7a6SAndroid Build Coastguard Worker
802*523fa7a6SAndroid Build Coastguard Worker    return ops_to_not_decompose
803*523fa7a6SAndroid Build Coastguard Worker
804*523fa7a6SAndroid Build Coastguard Worker
def _restore_transformed_ops_to_aten_ops(program: ExportedProgram):
    """
    Point transformed no-decomp ops in ``program`` back at their ATen ops.

    Any ``call_function`` node whose target, rendered with ``str``, is a key
    of ``transform_op_to_aten_op`` gets its target replaced by the mapped ATen
    op. The top-level graph is handled first, then each control-flow submodule.
    The program is modified in place and nothing is returned.
    """

    def _restore(graph):
        for node in graph.nodes:
            if node.op != "call_function":
                continue
            key = str(node.target)
            if key in transform_op_to_aten_op:
                node.target = transform_op_to_aten_op[key]

    _restore(program.graph)
    for _, submod, _ in get_control_flow_submodules(program.graph_module):
        _restore(submod.graph)
818*523fa7a6SAndroid Build Coastguard Worker
819*523fa7a6SAndroid Build Coastguard Worker
# Returns the op in edge_no_decomp_namespace namespace for the aten
# op that is passed in.
def _get_transformed_op(op_aten):
    """
    Look up the op in ``edge_no_decomp_namespace`` that mirrors ``op_aten``.

    The lookup uses the unqualified op name and the overload name taken from
    ``op_aten._schema``. An empty overload name selects the ``default`` overload.

    Raises:
        AttributeError: if the op/overload has not been registered in
            ``edge_no_decomp_namespace`` (e.g. ``_register_no_decomp_op`` was
            never called for ``op_aten``).
    """
    op_name = op_aten._schema.name.split("::", 1)[1]
    # An empty overload name denotes the default overload, which an op packet
    # exposes under the attribute ``default``.
    overload_name = op_aten._schema.overload_name or "default"
    # NOTE: ``torch.ops`` creates namespace objects lazily on attribute access,
    # so a ``hasattr(torch.ops, ...)`` check always succeeds. The meaningful
    # failure is the op/overload lookup inside the namespace, reported here.
    try:
        op_namespace = getattr(torch.ops, edge_no_decomp_namespace)
        op = getattr(op_namespace, op_name)
        return getattr(op, overload_name)
    except AttributeError as e:
        raise AttributeError(
            f"Couldn't find {op_name}.{overload_name} in torch.ops.{edge_no_decomp_namespace}. "
            "Please make sure the Library has been registered."
        ) from e
831*523fa7a6SAndroid Build Coastguard Worker
832*523fa7a6SAndroid Build Coastguard Worker
# Registers the op in edge_no_decomp_namespace namespace for the aten
# op that is passed in if it is not already cached in the table.
def _register_no_decomp_op(op_aten):
    """
    Define a clone of ``op_aten`` in the ``edge_no_decomp_namespace`` library.

    The clone is defined with ``op_aten``'s schema (namespace stripped), and
    ``op_aten`` itself is registered as its ``CompositeExplicitAutograd`` kernel.
    The ``aten_op_to_transform_op`` and ``transform_op_to_aten_op`` tables are
    then updated so later passes can map between the two.

    Arguments that are not ``torch._ops.OpOverload`` instances are ignored, as
    are ops that already have an entry in ``aten_op_to_transform_op``.
    """
    # Only concrete overloads carry a schema that can be cloned.
    if not isinstance(op_aten, torch._ops.OpOverload):
        return
    if aten_op_to_transform_op.get(op_aten) is not None:
        return

    # Strip the namespace ("aten::") from the schema. Split only on the first
    # "::" so that any later "::" in the signature is kept intact.
    op_schema = str(op_aten._schema).split("::", 1)[1]
    op_name = op_aten._schema.name.split("::", 1)[1]
    # Define an op in the edge_no_decomp_namespace namespace with the aten schema.
    lib.define(op_schema)

    # The implementation is the aten op itself. It is registered under
    # "<name>.<overload>" when the schema has a non-empty overload name.
    overload_name = op_aten._schema.overload_name
    if overload_name != "":
        op_name += "." + overload_name
    lib.impl(op_name, op_aten, "CompositeExplicitAutograd")

    # Cache the mapping in both directions. The reverse table is keyed by the
    # string form of the transformed op.
    transformed_op = _get_transformed_op(op_aten)
    aten_op_to_transform_op[op_aten] = transformed_op
    transform_op_to_aten_op[str(transformed_op)] = op_aten
857*523fa7a6SAndroid Build Coastguard Worker
858*523fa7a6SAndroid Build Coastguard Worker
def _sanity_check_graph_for_non_decomp_ops(
    name: str,
    program: ExportedProgram,
    ops_set_to_not_decompose,
    check_op_support,
    generate_error=False,
    partitioner_name=None,
):
    """
    Report preserved ops that are still present in ``program``.

    A node is reported when all of the following hold:
      - it is a ``call_function`` node, in the top-level graph or in a
        control-flow submodule;
      - its target is one of ``ops_set_to_not_decompose`` or the
        ``aten_to_edge`` form of one of them;
      - it passes ``check_op_support`` (when a predicate is given).

    Args:
        name: Method name, used in the message.
        program: Program to inspect. It is not modified.
        ops_set_to_not_decompose: ATen ops that were registered to not be
            decomposed.
        check_op_support: Optional node predicate. When falsy, every matching
            node is reported.
        generate_error: If true, raise on the first hit. Otherwise, log a
            warning for every hit.
        partitioner_name: Optional partitioner name included in the message.

    Raises:
        RuntimeError: if ``generate_error`` is true and a matching node is found.
    """
    warning_str = f"Found {ops_set_to_not_decompose} in edge dialect program {name}."
    if partitioner_name is not None:
        warning_str += f" This op was registered by the partitioner {partitioner_name} to not be decomposed."

    # Check that the ops that were registered to not be decomposed are not present in the
    # graph anymore as the transform passes and backends should have consumed them by now.
    # Both the ATen ops and their edge-dialect counterparts are matched.
    ops_set_to_not_decompose = {
        aten_to_edge(op) for op in ops_set_to_not_decompose
    }.union(ops_set_to_not_decompose)

    def _check_graph(graph):
        for node in graph.nodes:
            # The cheap structural filters run first, so check_op_support only
            # sees candidate nodes.
            if node.op != "call_function":
                continue
            if node.target not in ops_set_to_not_decompose:
                continue
            if check_op_support and not check_op_support(node):
                continue
            # Name the node actually found, not just the registered set.
            message = f"{warning_str} Offending node: {node.name} ({node.target})."
            if generate_error:
                raise RuntimeError(message)
            logging.warning(message)

    _check_graph(program.graph_module.graph)
    for _, submod, _ in get_control_flow_submodules(program.graph_module):
        _check_graph(submod.graph)
895*523fa7a6SAndroid Build Coastguard Worker
896*523fa7a6SAndroid Build Coastguard Worker
def _gen_edge_manager_for_partitioners(
    partitioner: Dict[str, List[Partitioner]],
    aten_programs: Dict[str, ExportedProgram],
    config: EdgeCompileConfig,
    constant_methods: Optional[Dict[str, Any]],
) -> "EdgeProgramManager":
    """
    Generates EdgeProgramManager for subsequent lowering to the
    partitioners specified by partitioner. The EdgeProgramManager is generated from
    aten_programs.

    Partitioners specify what nodes should not be decomposed from the original aten programs.
    This is done through two passes of run_decompositions.
        - First pass preserves all aten_targets specified by partitioners to preserve
          them from nested decompositions
        - Second pass uses check_op fn provided by partitioners to perform additional checks
          on nodes with preserved aten targets. They are then replaced with transformed ops to
          keep them through the second pass of decompositions

    After the second pass, the transformed ops are restored to their ATen ops.
    Each program is then converted with ``_generate_edge_program``, which
    receives the list of ops that were kept for that method.

    Args:
        partitioner: Mapping from method name to the partitioners for that method.
        aten_programs: Mapping from method name to an ATen-dialect program.
        config: Edge compile configuration forwarded to each edge program.
        constant_methods: Optional constant methods forwarded to the manager.

    Returns:
        An EdgeProgramManager holding the edge programs. It also receives the
        union of preserved ops across all methods.
    """
    ops_set_to_not_decompose_by_program = {}
    edge_programs: Dict[str, ExportedProgram] = {}
    for name, program in aten_programs.items():
        if partitioner is not None:
            # preserve all ops listed by all partitioners first
            all_ops_no_decomp = set()
            for curr_partitioner in partitioner.get(name, []):
                curr_ops_no_decomp, _ = curr_partitioner.ops_to_not_decompose(program)
                all_ops_no_decomp |= set(curr_ops_no_decomp)

            table = _default_decomposition_table()

            for op in all_ops_no_decomp:
                table.pop(op, None)

            program = program.run_decompositions(table)
            # Among all the preserved aten ops, use the check_op_fn to do an additional
            # check on which ops need to be preserved and which ops need to be decomposed
            # Those which are truly preserved will be replaced with transformed ops
            ops_set_to_not_decompose_by_program[name] = (
                _replace_aten_ops_with_transformed_ops(name, program, partitioner) or []
            )
        # Second pass: the transformed ops are not in the decomposition table,
        # so they survive this full decomposition.
        program = program.run_decompositions(_default_decomposition_table())

        _restore_transformed_ops_to_aten_ops(program)

        edge_programs[name] = _generate_edge_program(
            name,
            config,
            program,
            list(ops_set_to_not_decompose_by_program.get(name, [])),
        )

    edge_manager = EdgeProgramManager(
        edge_programs,
        constant_methods,
        config,
        list(set().union(*ops_set_to_not_decompose_by_program.values())),
    )
    return edge_manager
958*523fa7a6SAndroid Build Coastguard Worker
959*523fa7a6SAndroid Build Coastguard Worker
def to_edge_transform_and_lower(
    programs: Union[ExportedProgram, Dict[str, ExportedProgram]],
    transform_passes: Optional[
        Union[Sequence[PassType], Dict[str, Sequence[PassType]]]
    ] = None,
    partitioner: Optional[
        Union[
            Partitioner,
            List[Partitioner],
            Dict[str, Union[Partitioner, List[Partitioner]]],
        ]
    ] = None,
    constant_methods: Optional[Dict[str, Any]] = None,
    compile_config: Optional[EdgeCompileConfig] = None,
) -> "EdgeProgramManager":
    """
    :func:`to_edge_transform_and_lower` constructs an EdgeProgramManager from a set of
    exported programs in ATen dialect. It differs fundamentally from to_edge in that it
    combines the conversion of the ATen dialect to the edge dialect program, then running
    the transformation passes and then subsequently lowering the programs to their
    corresponding backends all into a single API.

    This is fundamentally useful for lowering to backends that have ops registered that they
    do not want to be decomposed and thus rely on matching with these non-decomposed ops. For
    these sorts of backends this is the *only* API that should be used to lower to the edge
    dialect. Using a combination of to_edge(...) and to_backend(...) will result in inconsistent
    or wrong behavior.

    This API is the primary recommended way to lower to the CPU based XNNPack backend.

    Args:
        programs: Can be a single ExportedProgram or a dictionary mapping function names
            to their corresponding ExportedPrograms. If only a single ExportedProgram is
            provided it will be assigned the name "forward".

        transform_passes: The passes can either be a list of passes, or a dictionary
            mapping method names to lists of passes. If it is just a list of passes, all methods
            in the given EdgeProgramManager will be transformed with the provided passes. If it
            is a dictionary, only method names specified in the dictionary will be transformed
            with their corresponding passes.

        partitioner: Either a list of Partitioner instances (or a single Partitioner
            instance), or a dictionary mapping method names to such a list (or single
            instance). A list or single instance is applied to every program in the given
            EdgeProgramManager. With a dictionary, only the method names it contains are
            lowered. The partitioners for a method are applied in list order.

        constant_methods: An optional dictionary of method name to the constant value
            returned by that method in eager mode. Often used to store config information on
            Edge models.

        compile_config: An optional argument used to provide greater control over the
            transformation to edge dialect process.

    Returns:
        EdgeProgramManager

    Raises:
        RuntimeError: if an op a partitioner registered to not be decomposed is
            still present after lowering and passes that partitioner's check.
    """
    assert not isinstance(constant_methods, EdgeCompileConfig)
    config = compile_config or EdgeCompileConfig()
    if not isinstance(programs, dict):
        aten_programs = {"forward": programs}
    else:
        aten_programs = programs

    # Normalize ``partitioner`` into Dict[str, List[Partitioner]]. A bare
    # Partitioner instance, at the top level or as a dict value, is wrapped in
    # a single-element list.
    if partitioner is None:
        partitioner = {name: [] for name in aten_programs.keys()}
    elif isinstance(partitioner, dict):
        partitioner = {
            name: [parts] if isinstance(parts, Partitioner) else parts
            for name, parts in partitioner.items()
        }
    else:
        if isinstance(partitioner, Partitioner):
            partitioner = [partitioner]
        partitioner = {name: partitioner for name in aten_programs.keys()}

    edge_manager = _gen_edge_manager_for_partitioners(
        partitioner, aten_programs, config, constant_methods
    )

    if transform_passes is not None:
        edge_manager = edge_manager.transform(transform_passes)

    # ``partitioner`` is always a dict at this point.
    for name, partitioner_list in partitioner.items():
        for curr_partitioner in partitioner_list:
            edge_manager = edge_manager.to_backend({name: curr_partitioner})

    for name, program in edge_manager._edge_programs.items():
        ops_set_to_not_decompose: Set[torch._ops.OpOverload] = set()
        partitioners = partitioner.get(name, [])
        for curr_partitioner in partitioners:
            curr_op_set, check_op_support = curr_partitioner.ops_to_not_decompose(
                program
            )
            # Accumulated across partitioners for the IR verifier below.
            ops_set_to_not_decompose = ops_set_to_not_decompose.union(curr_op_set)
            # Check only this partitioner's ops, with this partitioner's
            # predicate, so an error names the partitioner that registered the op.
            _sanity_check_graph_for_non_decomp_ops(
                name,
                program,
                set(curr_op_set),
                check_op_support,
                partitioner_name=curr_partitioner.__class__.__name__,
                generate_error=True,
            )

        if config._check_ir_validity:
            EXIREdgeDialectVerifier(
                edge_compile_config=config,
                class_only=True,
                exception_list=list(ops_set_to_not_decompose),
            )()(program.graph_module)

    return edge_manager
1062*523fa7a6SAndroid Build Coastguard Worker
1063*523fa7a6SAndroid Build Coastguard Worker
@experimental(
    """
    This is an experimental API which overloads to_edge by preserving specified ops to not be decomposed.
    This function will be combined with to_edge in the future.
    """
)
def to_edge_with_preserved_ops(
    programs: Union[ExportedProgram, Dict[str, ExportedProgram]],
    constant_methods: Optional[Dict[str, Any]] = None,
    compile_config: Optional[EdgeCompileConfig] = None,
    preserve_ops: Tuple[torch._ops.OpOverload, ...] = (),
) -> "EdgeProgramManager":
    """
    :func:`to_edge_with_preserved_ops` constructs an EdgeProgramManager from a set of
    exported programs in ATen dialect, like :func:`to_edge`. Upon construction those
    programs are transformed into edge dialect.

    Unlike :func:`to_edge`, the ops in ``preserve_ops`` are removed from the default
    decomposition table before decomposing to Core ATen. The same ops are also forwarded
    to ``_generate_edge_program`` and ``EdgeProgramManager``.

    Args:
        programs: Can be a single ExportedProgram or a dictionary mapping function names to their corresponding ExportedPrograms. If only a single ExportedProgram is provided it will be assigned the name "forward".
        constant_methods: An optional dictionary of method name to the constant value returned by that method in eager mode. Often used to store config information on Edge models.
        compile_config: An optional argument used to provide greater control over the transformation to edge dialect process.
        preserve_ops: An argument used to specify ops that should not be decomposed.

    Returns:
        EdgeProgramManager
    """
    assert not isinstance(constant_methods, EdgeCompileConfig)
    config = compile_config or EdgeCompileConfig()
    if not isinstance(programs, dict):
        aten_programs = {"forward": programs}
    else:
        aten_programs = programs

    # Materialize once. A one-shot iterable would otherwise be exhausted by the
    # first decomposition table, leaving the later uses empty.
    preserved = tuple(preserve_ops)

    edge_programs: Dict[str, ExportedProgram] = {}

    for name, program in aten_programs.items():
        # Decompose to Core ATen, skipping the preserved ops.
        table = _default_decomposition_table()
        for op in preserved:
            table.pop(op, None)
        program = program.run_decompositions(table)
        edge_programs[name] = _generate_edge_program(
            name, config, program, list(preserved)
        )

    return EdgeProgramManager(
        edge_programs, constant_methods, config, list(preserved)
    )
1111*523fa7a6SAndroid Build Coastguard Worker
1112*523fa7a6SAndroid Build Coastguard Worker
def to_edge(
    programs: Union[ExportedProgram, Dict[str, ExportedProgram]],
    constant_methods: Optional[Dict[str, Any]] = None,
    compile_config: Optional[EdgeCompileConfig] = None,
) -> "EdgeProgramManager":
    """
    Convert ATen-dialect exported programs into an EdgeProgramManager.

    Each program is decomposed with the default decomposition table (Core ATen)
    and then turned into an edge-dialect program with ``_generate_edge_program``.

    Args:
        programs: A single ExportedProgram, which is stored under the method
            name "forward", or a dict mapping method names to ExportedPrograms.
        constant_methods: Optional mapping from method name to the constant
            value that method returns in eager mode. Often used to store config
            information on Edge models.
        compile_config: Optional EdgeCompileConfig controlling the conversion.
            A default EdgeCompileConfig() is used when omitted.

    Returns:
        EdgeProgramManager wrapping the converted programs.

    Raises:
        AssertionError: if ``constant_methods`` is an EdgeCompileConfig.
    """
    assert not isinstance(constant_methods, EdgeCompileConfig)
    config = compile_config or EdgeCompileConfig()
    aten_programs = programs if isinstance(programs, dict) else {"forward": programs}

    edge_programs: Dict[str, ExportedProgram] = {}
    for method_name, aten_program in aten_programs.items():
        # Lower to Core ATen first, then convert to edge dialect.
        core_aten_program = aten_program.run_decompositions(
            _default_decomposition_table()
        )
        edge_programs[method_name] = _generate_edge_program(
            method_name, config, core_aten_program
        )

    return EdgeProgramManager(edge_programs, constant_methods, config)
1147*523fa7a6SAndroid Build Coastguard Worker
1148*523fa7a6SAndroid Build Coastguard Worker
class EdgeProgramManager:
    """
    Package of one or more `ExportedPrograms` in Edge dialect. Designed to simplify
    lowering to ExecuTorch. See: https://pytorch.org/executorch/stable/ir-exir.html

    Allows easy applications of transforms across a collection of exported programs
    including the delegation of subgraphs.

    Manages the second link in the lowering chain of ATen -> Edge -> ExecuTorch.
    """

    def __init__(
        self,
        edge_programs: Union[ExportedProgram, Dict[str, ExportedProgram]],
        constant_methods: Optional[Dict[str, Any]] = None,
        compile_config: Optional[EdgeCompileConfig] = None,
        ops_set_to_not_decompose: Optional[List[torch._ops.OpOverload]] = None,
    ):
        """
        Should not be called directly by users. Users should use :func:`to_edge` instead.

        Constructs an EdgeProgramManager from an existing set of exported programs in
        edge dialect.

        Args:
            edge_programs: A single ExportedProgram (registered under the method name
                "forward") or a dictionary mapping method names to ExportedPrograms.
            constant_methods: Optional dictionary of method name to the constant value
                returned by that method in eager mode.
            compile_config: Config handed to the Edge dialect verifier. Defaults to
                ``EdgeCompileConfig()`` when not provided.
            ops_set_to_not_decompose: Ops passed to the Edge dialect verifier as its
                ``exception_list``.

        Raises:
            ExportError: If any program fails Edge dialect verification.
        """
        self.compile_config = compile_config or EdgeCompileConfig()
        if not isinstance(edge_programs, dict):
            edge_programs = {"forward": edge_programs}

        for name, program in edge_programs.items():
            try:
                EXIREdgeDialectVerifier(
                    edge_compile_config=self.compile_config,
                    exception_list=ops_set_to_not_decompose,
                )(program.graph_module)
            except ExportError:
                # The verifier above checks the Edge dialect, so report it as such.
                logging.info(f"Input program {name} is not in Edge dialect.")
                raise

        self._edge_programs: Dict[str, ExportedProgram] = edge_programs
        self._config_methods = constant_methods

    @staticmethod
    def _warn_on_unknown_methods(
        api_name: str,
        requested: Dict[str, Any],
        available: Dict[str, ExportedProgram],
    ) -> None:
        """Log a warning for keys in ``requested`` that are not methods in ``available``.

        Such entries are ignored by the per-method APIs; the warning makes typos visible
        instead of silently doing nothing.
        """
        unknown = [method for method in requested if method not in available]
        if unknown:
            logging.warning(
                f"{api_name}: ignoring entries for unknown method(s) {unknown}; "
                f"available methods: {list(available)}"
            )

    @property
    def methods(self) -> Set[str]:
        """
        Returns the set of methods in this EdgeProgramManager.
        """
        return set(self._edge_programs.keys())

    @property
    def config_methods(self) -> Set[str]:
        """
        Returns the set of config methods in this EdgeProgramManager.
        """
        return set(self._config_methods.keys()) if self._config_methods else set()

    def exported_program(self, method_name: str = "forward") -> ExportedProgram:
        """
        Returns the ExportedProgram specified by 'method_name'.

        Raises:
            KeyError: If no method with that name exists.
        """
        return self._edge_programs[method_name]

    def transform(
        self,
        passes: Union[Sequence[PassType], Dict[str, Sequence[PassType]]],
        compile_config: Optional[EdgeCompileConfig] = None,
    ) -> "EdgeProgramManager":
        """
        Transforms the program according to the provided passes.

        Args:
            passes: The passes can either be a list of passes, or a
                dictionary mapping method names to lists of passes. If it is
                just a list of passes, all methods in the given EdgeProgramManager
                will be transformed with the provided passes. If it is a
                dictionary, only method names specified in the dictionary will be
                transformed with their corresponding passes; other methods are
                deep-copied unchanged, and dictionary keys that do not name an
                existing method are ignored (with a warning).
            compile_config: Compile config used to verify the correctness of the model
                graph after the passes run. If not specified, the compile config of the
                calling EdgeProgramManager is used. It is also used as the compile
                config of the returned EdgeProgramManager.

        Returns:
            EdgeProgramManager: A copy of the calling EdgeProgramManager with the
            transformations applied.
        """
        compile_config = compile_config or self.compile_config
        new_programs: Dict[str, ExportedProgram] = {}
        if isinstance(passes, dict):
            self._warn_on_unknown_methods("transform", passes, self._edge_programs)
            for name, program in self._edge_programs.items():
                if name in passes:
                    new_programs[name] = _transform(program, *passes[name])
                    EXIREdgeDialectVerifier(edge_compile_config=compile_config)(
                        new_programs[name].graph_module
                    )
                else:
                    new_programs[name] = copy.deepcopy(program)

        else:  # apply passes to every method
            for name, program in self._edge_programs.items():
                new_programs[name] = _transform(program, *passes)
                EXIREdgeDialectVerifier(edge_compile_config=compile_config)(
                    new_programs[name].graph_module
                )

        return EdgeProgramManager(
            new_programs, copy.deepcopy(self._config_methods), compile_config
        )

    def to_backend(
        self, partitioner: Union[Partitioner, Dict[str, Partitioner]]
    ) -> "EdgeProgramManager":
        """
        Returns a semantically-equivalent program to the one given as input,
        but with portions of each program in the EdgeProgramManager targeted
        for delegation as determined by the partitioner.

        Args:
            partitioner: The partitioner can either be a Partitioner subclass instance, or a
                dictionary mapping method names to Partitioner subclass instances. If it is a
                Partitioner subclass, all programs in the given EdgeProgramManager
                will be lowered using the given partitioner. If it is a
                dictionary, only method names specified in the dictionary will be
                lowered with the given partitioner; dictionary keys that do not name
                an existing method are ignored (with a warning).

                The Partitioner subclass instance is in charge of tagging portions of the
                input program for delegation. A valid partitioner must return a
                PartitionerResult including valid partition_tags: Dict[str, DelegationSpec],
                where each key is a tag name and the nodes with the same tag will be fused
                into one subgraph and delegated to the backend specified in the
                delegation spec.

        Returns:
            EdgeProgramManager: A copy of the calling EdgeProgramManager with the
            specified subgraphs lowered. Its compile config has IR validity checking
            disabled.
        """
        new_edge_programs: Dict[str, ExportedProgram] = {}
        if isinstance(partitioner, dict):
            self._warn_on_unknown_methods(
                "to_backend", partitioner, self._edge_programs
            )
            for name, program in self._edge_programs.items():
                if name in partitioner:
                    new_edge_programs[name] = to_backend(program, partitioner[name])
                else:
                    new_edge_programs[name] = program

        else:  # apply partitioner to every method
            for name, program in self._edge_programs.items():
                new_edge_programs[name] = to_backend(program, partitioner)

        config = EdgeCompileConfig(_check_ir_validity=False)
        return EdgeProgramManager(
            new_edge_programs, copy.deepcopy(self._config_methods), config
        )

    def to_executorch(
        self,
        config: Optional[ExecutorchBackendConfig] = None,
    ) -> "ExecutorchProgramManager":
        """
        Transforms the program to the ExecuTorch backend.

        Args:
            config: An optional argument used to provide greater control over
                the transformation to the ExecuTorch backend.

        Returns:
            ExecutorchProgramManager: A manager representing the state of the EdgeProgramManager
            after it has been transformed to the ExecuTorch backend.
        """
        config = config if config else ExecutorchBackendConfig()

        execution_programs: Dict[str, ExportedProgram] = {}
        for name, program in self._edge_programs.items():
            program = weights_to_outputs_pass(program)
            program = unsafe_remove_auto_functionalized_pass(program)
            # Only the updated signature is used from this pass; the passes below
            # read program.graph_module directly, which relies on the pass having
            # updated the program's graph module in place.
            _, new_signature = insert_write_back_for_buffers_pass(program)
            new_gm = program.graph_module
            for p in edge_to_executorch_passes(config, name):
                new_gm_res = p(new_gm)
                assert new_gm_res is not None
                new_gm = new_gm_res.graph_module
                if isinstance(p, SpecPropPass):
                    # Note that this is a hacky way to get around the fact that
                    # placeholder nodes corresponding to the parameters of the graph module
                    # shall not participate in memory planning. It increases runtime memory
                    # footprint.
                    # Proper way would be to have ExportPass work with ExportedProgram
                    # instead of GraphModule. This is because ExportPass should work
                    # on top of the export artifact of torch.export which is ExportedProgram.
                    # Working with GraphModule does not provide all the information contained
                    # in the ExportedProgram
                    # TODO(who?)
                    p.update_placeholder_tensor_specs(program, new_gm)

            # A dict config selects a per-method memory planning pass, falling back to
            # the default config's pass for methods it does not mention.
            if isinstance(config.memory_planning_pass, dict):
                memory_planning_pass = config.memory_planning_pass.get(
                    name, ExecutorchBackendConfig().memory_planning_pass
                )
            else:
                memory_planning_pass = config.memory_planning_pass
            # TODO(jakeszwe): Follow up with compiler on if the deepcopy is necessary and if so how to make it work
            if hasattr(memory_planning_pass, "run"):
                new_gm_res = memory_planning_pass.run(  # pyre-ignore[16]
                    new_gm, new_signature
                )
            else:
                new_gm_res = memory_planning_pass(new_gm)  # pyre-ignore[29]
            assert new_gm_res is not None
            new_gm = new_gm_res.graph_module

            _copy_module(program.graph_module, new_gm)
            execution_programs[name] = program

        return ExecutorchProgramManager(
            execution_programs, self._config_methods, config
        )
1361*523fa7a6SAndroid Build Coastguard Worker
1362*523fa7a6SAndroid Build Coastguard Worker
class ExecutorchProgramManager:
    """
    Package of one or more `ExportedPrograms` in Execution dialect. Designed to simplify
    lowering to ExecuTorch. See: https://pytorch.org/executorch/stable/ir-exir.html

    Constructing an ExecutorchProgramManager emits the execution-dialect programs into
    the ExecuTorch program representation and then serializes that result into an
    in-memory buffer that can be written to a file.

    Manages the final link in the lowering chain of ATen -> Edge -> ExecuTorch.
    """

    def __init__(
        self,
        execution_programs: Dict[str, ExportedProgram],
        config_methods: Optional[Dict[str, Any]] = None,
        backend_config: Optional[ExecutorchBackendConfig] = None,
    ):
        """
        End users should not call this constructor directly. Instead, they should use
        :func:`to_executorch` to construct an ExecutorchProgramManager.

        Constructs an ExecutorchProgramManager from a set of exported programs in
        execution dialect.

        Args:
            execution_programs: Dictionary mapping each method name to its
                ExportedProgram.
            config_methods: Dictionary mapping each config method name to the value
                that method returns in eager mode.
            backend_config: Optional settings controlling emission and serialization.
                Defaults to ``ExecutorchBackendConfig()`` when not provided.
        """
        # Keep the inputs around so the accessors below can expose them.
        self._execution_programs: Dict[str, ExportedProgram] = execution_programs
        self._config_methods: Optional[Dict[str, Any]] = config_methods

        cfg = backend_config or ExecutorchBackendConfig()

        # Emission: turn the execution-dialect programs into the ExecuTorch program.
        self._emitter_output: EmitterOutput = emit_program(
            self._execution_programs,
            cfg.emit_stacktrace,
            self._config_methods,
        )

        # Serialization: produce the .pte payload from the emitted program.
        emitted = self._emitter_output
        self._pte_data: Cord = _serialize_pte_binary(
            program=emitted.program,
            mutable_data=emitted.mutable_data,
            extract_delegate_segments=cfg.extract_delegate_segments,
            segment_alignment=cfg.segment_alignment,
            constant_tensor_alignment=cfg.constant_tensor_alignment,
            delegate_alignment=cfg.delegate_alignment,
        )
        # Contiguous copy of the payload; only created on first access to `buffer`.
        self._buffer: Optional[bytes] = None

    @property
    def methods(self) -> Set[str]:
        """
        Returns the names of the methods held by this ExecutorchProgramManager.
        """
        return set(self._execution_programs)

    @property
    def config_methods(self) -> Set[str]:
        """
        Returns the names of the config methods held by this ExecutorchProgramManager.
        An empty set is returned when no config methods were provided.
        """
        if not self._config_methods:
            return set()
        return set(self._config_methods)

    def exported_program(self, method_name: str = "forward") -> ExportedProgram:
        """
        Returns the ExportedProgram registered under 'method_name'.

        Raises:
            KeyError: If no method with that name exists.
        """
        return self._execution_programs[method_name]

    def dump_executorch_program(
        self, verbose: bool = False, out: Optional[TextIO] = None
    ) -> None:
        """
        Prints the ExecuTorch binary in a human readable format.

        Args:
            verbose (bool):
                If False prints the binary in a condensed format.
                If True prints the binary 1-1 with the specification in the schema.
            out:
                If None, prints to stdout.
                If non-None, writes the string to that stream object. It can be
                    a file object, a StringIO object, or any other TextIO subclass.
        """
        printer = pretty_print if verbose else print_program
        printer(self._emitter_output.program, out=out)

    @property
    def debug_handle_map(self) -> Dict[int, Union[int, List[int]]]:
        """
        Returns the debug handle map produced by the emitter for this program.
        """
        return self._emitter_output.debug_handle_map

    @property
    def delegate_map(
        self,
    ) -> Dict[str, Dict[int, Dict[str, Union[str, _DelegateDebugIdentifierMap]]]]:
        """
        Returns the emitter's per-method delegate debug identifier map
        (``method_to_delegate_debug_id_map``).
        """
        return self._emitter_output.method_to_delegate_debug_id_map

    @property
    def executorch_program(self) -> Program:
        """
        Returns the object that represents the ExecuTorch binary before serialization.
        """
        return self._emitter_output.program

    @property
    def buffer(self) -> bytes:
        """Returns the serialized ExecuTorch binary as a byte string.

        Note that the call to `buffer` may allocate a very large amount of
        contiguous memory, depending on the model size. If writing to a file,
        use `write_to_file` which won't incur additional copies.
        """
        # TODO(T181494963): update pybinding to remove buffer cache, which can consume large
        # amounts of memory longer than necessary.
        cached = self._buffer
        if cached is None:
            # Materialize once and keep the result for later calls.
            cached = bytes(self._pte_data)
            self._buffer = cached
        return cached

    def write_to_file(self, open_file: io.BufferedIOBase) -> None:
        """
        Writes the serialized ExecuTorch binary to `open_file`. Prefer this over
        `buffer`: it streams the data to the file without first copying it into a
        contiguous block of memory, which reduces peak memory usage.
        """
        self._pte_data.write_to_file(open_file)
1500