xref: /aosp_15_r20/external/executorch/exir/backend/backend_api.py (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1*523fa7a6SAndroid Build Coastguard Worker# Copyright (c) Meta Platforms, Inc. and affiliates.
2*523fa7a6SAndroid Build Coastguard Worker# All rights reserved.
3*523fa7a6SAndroid Build Coastguard Worker#
4*523fa7a6SAndroid Build Coastguard Worker# This source code is licensed under the BSD-style license found in the
5*523fa7a6SAndroid Build Coastguard Worker# LICENSE file in the root directory of this source tree.
6*523fa7a6SAndroid Build Coastguard Worker
7*523fa7a6SAndroid Build Coastguard Workerimport copy
8*523fa7a6SAndroid Build Coastguard Workerimport logging
9*523fa7a6SAndroid Build Coastguard Workerfrom contextlib import contextmanager, nullcontext
10*523fa7a6SAndroid Build Coastguard Workerfrom functools import singledispatch
11*523fa7a6SAndroid Build Coastguard Workerfrom typing import Generator, List
12*523fa7a6SAndroid Build Coastguard Worker
13*523fa7a6SAndroid Build Coastguard Workerimport torch
14*523fa7a6SAndroid Build Coastguard Worker
15*523fa7a6SAndroid Build Coastguard Workerfrom executorch.exir.backend.backend_details import BackendDetails, PreprocessResult
16*523fa7a6SAndroid Build Coastguard Workerfrom executorch.exir.backend.compile_spec_schema import CompileSpec
17*523fa7a6SAndroid Build Coastguard Worker
18*523fa7a6SAndroid Build Coastguard Workerfrom executorch.exir.backend.partitioner import Partitioner, PartitionResult
19*523fa7a6SAndroid Build Coastguard Workerfrom executorch.exir.backend.utils import (
20*523fa7a6SAndroid Build Coastguard Worker    _maybe_duplicate_constant_nodes,
21*523fa7a6SAndroid Build Coastguard Worker    is_identical_graph,
22*523fa7a6SAndroid Build Coastguard Worker)
23*523fa7a6SAndroid Build Coastguard Worker
24*523fa7a6SAndroid Build Coastguard Workerfrom executorch.exir.delegate import executorch_call_delegate, get_lowered_module_name
25*523fa7a6SAndroid Build Coastguard Worker
26*523fa7a6SAndroid Build Coastguard Workerfrom executorch.exir.graph_module import get_control_flow_submodules
27*523fa7a6SAndroid Build Coastguard Workerfrom executorch.exir.lowered_backend_module import (
28*523fa7a6SAndroid Build Coastguard Worker    _unsafe_adjust_original_program,
29*523fa7a6SAndroid Build Coastguard Worker    create_exported_program_from_submodule,
30*523fa7a6SAndroid Build Coastguard Worker    create_submodule_from_nodes,
31*523fa7a6SAndroid Build Coastguard Worker    LoweredBackendModule,
32*523fa7a6SAndroid Build Coastguard Worker)
33*523fa7a6SAndroid Build Coastguard Workerfrom executorch.exir.program._fake_program import (
34*523fa7a6SAndroid Build Coastguard Worker    get_fake_program,
35*523fa7a6SAndroid Build Coastguard Worker    update_to_real_program,
36*523fa7a6SAndroid Build Coastguard Worker)
37*523fa7a6SAndroid Build Coastguard Workerfrom torch._export.utils import is_buffer, is_lifted_tensor_constant, is_param
38*523fa7a6SAndroid Build Coastguard Workerfrom torch.export import ExportedProgram
39*523fa7a6SAndroid Build Coastguard Worker
40*523fa7a6SAndroid Build Coastguard Worker
@singledispatch
def to_backend(args):
    """
    Generic lowering entry point; the implementation is chosen from the type of the first argument.

    Python is dynamically typed, so it has no compile-time method overloading.
    Implementations are instead attached with ``@to_backend.register`` and are
    selected at call time from the type annotation of their first parameter
    (the annotation is required). Dispatch only looks at that first argument.

    The overloads registered in this file are:

    ::

     def to_backend(
         backend_id: str,
         edge_program: ExportedProgram,
         compile_specs: List[CompileSpec],
     ) -> LoweredBackendModule:

     def to_backend(
         edge_program: ExportedProgram,
         partitioner_instance: Partitioner,
     ) -> ExportedProgram:

    A call whose first argument matches no registered type lands here; this
    base implementation does nothing and returns None.
    """
    return None
64*523fa7a6SAndroid Build Coastguard Worker
65*523fa7a6SAndroid Build Coastguard Worker
@to_backend.register
def _(
    backend_id: str,
    edge_program: ExportedProgram,
    compile_specs: List[CompileSpec],
) -> LoweredBackendModule:
    """
    ``to_backend`` overload that lowers a whole program to one named backend:

    ::

     def to_backend(
         backend_id: str,
         edge_program: ExportedProgram,
         compile_specs: List[CompileSpec],
     ) -> LoweredBackendModule:

    The backend is looked up among the direct subclasses of ``BackendDetails``
    by comparing ``backend_id`` with the subclass name. Its ``preprocess`` is
    run on a deep copy of ``edge_program`` and the result is wrapped in a
    ``LoweredBackendModule``.

    Args:
        backend_id: The backend identifier; must equal the class name of a
            direct ``BackendDetails`` subclass.
        edge_program: The exported program to lower. ``preprocess`` only sees
            a deep copy; the object passed in here is what gets stored on the
            returned module.
        compile_specs: Backend-specific compile specs, forwarded unchanged to
            ``preprocess`` and stored on the returned module.

    Returns:
        LoweredBackendModule: Built from ``edge_program``, ``backend_id``,
        the ``processed_bytes`` returned by ``preprocess`` and
        ``compile_specs``. Its ``meta`` is set to a dict holding the
        ``debug_handle_map`` returned by ``preprocess``.

    Raises:
        AssertionError: ``edge_program`` is not an ``ExportedProgram``.
        NotImplementedError: No direct ``BackendDetails`` subclass is named
            ``backend_id``. (NotImplementedError derives from RuntimeError.)
    """
    assert isinstance(edge_program, ExportedProgram)

    # Backend implementations are final, so only direct subclasses are searched.
    # The first subclass whose name matches wins.
    backend_cls = next(
        (
            candidate
            for candidate in BackendDetails.__subclasses__()
            if backend_id == candidate.__name__
        ),
        None,
    )
    if backend_cls is None:
        raise NotImplementedError(f"Backend {backend_id} was not found.")

    # preprocess works on a private deep copy; the lowered module keeps the
    # program object that the caller handed in.
    program_for_backend = copy.deepcopy(edge_program)
    result: PreprocessResult = backend_cls.preprocess(
        program_for_backend,
        compile_specs,
    )

    lowered = LoweredBackendModule(
        edge_program=edge_program,
        backend_id=backend_id,
        processed_bytes=result.processed_bytes,
        compile_specs=compile_specs,
    )
    lowered.meta = {"debug_handle_map": result.debug_handle_map}
    return lowered
128*523fa7a6SAndroid Build Coastguard Worker
129*523fa7a6SAndroid Build Coastguard Worker
# Module-wide switch, True by default. The partitioner overload of to_backend
# reads it: when True it asserts that the tagged graph is identical to the
# original graph, when False it only logs a warning that the check is skipped.
_ENABLE_VALIDATION: bool = True
131*523fa7a6SAndroid Build Coastguard Worker
132*523fa7a6SAndroid Build Coastguard Worker
def disable_validation() -> None:
    """
    Switch validation off by setting the module-level ``_ENABLE_VALIDATION`` flag to False.

    This function never sets the flag back; it stays False until other code
    assigns it again.
    """
    global _ENABLE_VALIDATION
    _ENABLE_VALIDATION = False
    return None
137*523fa7a6SAndroid Build Coastguard Worker
138*523fa7a6SAndroid Build Coastguard Worker
@contextmanager
def validation_disabled() -> Generator[None, None, None]:
    """
    Context manager that turns validation off for the duration of the ``with`` body.

    While active, checking functions (ex. whether the partitioned graph is
    identical to the original graph) are skipped. Intended only for specific
    situations, such as when profiling has shown the checks to be too slow and
    they are not strictly needed.

    On exit, including when the body raises, ``_ENABLE_VALIDATION`` is set
    back to the value it had on entry.
    """
    global _ENABLE_VALIDATION
    setting_on_entry = _ENABLE_VALIDATION
    disable_validation()
    try:
        yield
    finally:
        # Restore the saved value rather than forcing True, so that entering
        # this context while validation is already off leaves it off.
        _ENABLE_VALIDATION = setting_on_entry
154*523fa7a6SAndroid Build Coastguard Worker
155*523fa7a6SAndroid Build Coastguard Worker
156*523fa7a6SAndroid Build Coastguard Workerdef _get_node_list_with_same_tag(
157*523fa7a6SAndroid Build Coastguard Worker    tagged_graph_module: torch.fx.GraphModule,
158*523fa7a6SAndroid Build Coastguard Worker    tag: str,
159*523fa7a6SAndroid Build Coastguard Worker    owning_program: ExportedProgram,
160*523fa7a6SAndroid Build Coastguard Worker) -> List[torch.fx.Node]:
161*523fa7a6SAndroid Build Coastguard Worker    """
162*523fa7a6SAndroid Build Coastguard Worker    Return a list of nodes with the same tag.
163*523fa7a6SAndroid Build Coastguard Worker    """
164*523fa7a6SAndroid Build Coastguard Worker    node_list = []
165*523fa7a6SAndroid Build Coastguard Worker
166*523fa7a6SAndroid Build Coastguard Worker    for node in tagged_graph_module.graph.nodes:
167*523fa7a6SAndroid Build Coastguard Worker        if node.meta.get("delegation_tag", "") == tag:
168*523fa7a6SAndroid Build Coastguard Worker            if node.op == "output":
169*523fa7a6SAndroid Build Coastguard Worker                raise RuntimeError(f"output node {node} should not be tagged")
170*523fa7a6SAndroid Build Coastguard Worker            if node.op == "placeholder":
171*523fa7a6SAndroid Build Coastguard Worker                if (
172*523fa7a6SAndroid Build Coastguard Worker                    not is_param(owning_program, node)
173*523fa7a6SAndroid Build Coastguard Worker                    and not is_buffer(owning_program, node)
174*523fa7a6SAndroid Build Coastguard Worker                    and not is_lifted_tensor_constant(owning_program, node)
175*523fa7a6SAndroid Build Coastguard Worker                ):
176*523fa7a6SAndroid Build Coastguard Worker                    raise RuntimeError(
177*523fa7a6SAndroid Build Coastguard Worker                        f"placeholder node for non-params, non-buffer, and non-tensor constants should not be tagged: {node} "
178*523fa7a6SAndroid Build Coastguard Worker                    )
179*523fa7a6SAndroid Build Coastguard Worker                else:
180*523fa7a6SAndroid Build Coastguard Worker                    # check that the users all belong to the same tag
181*523fa7a6SAndroid Build Coastguard Worker                    for user in node.users:
182*523fa7a6SAndroid Build Coastguard Worker                        users_tag = user.meta.get("delegation_tag", None)
183*523fa7a6SAndroid Build Coastguard Worker                        if users_tag != tag:
184*523fa7a6SAndroid Build Coastguard Worker                            raise RuntimeError(
185*523fa7a6SAndroid Build Coastguard Worker                                f"constant data node ({node}) is tagged with ({tag}) but has user ({user}) which has tag ({users_tag})"
186*523fa7a6SAndroid Build Coastguard Worker                            )
187*523fa7a6SAndroid Build Coastguard Worker            node_list.append(node)
188*523fa7a6SAndroid Build Coastguard Worker    return node_list
189*523fa7a6SAndroid Build Coastguard Worker
190*523fa7a6SAndroid Build Coastguard Worker
def _partition_and_lower_one_graph_module(
    tagged_graph_module: torch.fx.GraphModule,
    partition_result: PartitionResult,
    owning_program: ExportedProgram,
    is_submodule: bool,
) -> torch.fx.GraphModule:
    """
    Partition and lower a single graph module according to its partition tags.

    For every ``(tag, delegation_spec)`` in ``partition_result.partition_tags``
    the nodes carrying that tag are fused into one submodule, the submodule is
    turned into an exported program and lowered with ``to_backend``, and the
    call to the submodule is replaced in ``tagged_graph_module`` by an
    ``executorch_call_delegate`` call. Tags with no nodes in this graph module
    are skipped.

    Args:
        tagged_graph_module: The graph module to rewrite; it is modified in place.
        partition_result: Supplies the tag -> delegation spec mapping.
        owning_program: The exported program the graph module belongs to.
        is_submodule: True when ``tagged_graph_module`` is a control-flow
            submodule rather than the top-level graph module. In that case no
            replace hook is installed and no top-level specs may be deleted.

    Returns:
        The same ``tagged_graph_module`` object that was passed in.
    """

    def _next_debug_handle(ep: ExportedProgram) -> int:
        """
        Return one more than the largest "debug_handle" found in ep's graph
        (nodes without one count as 0, and the result is never below 1).
        """
        handles = [
            graph_node.meta.get("debug_handle", 0)
            for graph_node in ep.graph_module.graph.nodes
        ]
        return max([0] + handles) + 1

    for tag, delegation_spec in partition_result.partition_tags.items():
        # Create partition with nodes containing this tag. There should only be
        # one contained submodule per tag
        node_list = _get_node_list_with_same_tag(
            tagged_graph_module, tag, owning_program
        )

        if not node_list:
            logging.debug(f"Did not find any nodes for tag {tag}")
            continue

        logging.debug(f"For tag {tag}, found nodes {node_list}")
        # Tag the nodes that are params as buffers, so we can order the submodule as (Parms + Buffers) (User Inputs)

        # Only the top-level graph module gets the signature's replace hook;
        # control-flow submodules are fused without one.
        if is_submodule:
            replace_ctx = nullcontext()
        else:
            replace_ctx = tagged_graph_module._set_replace_hook(
                owning_program.graph_signature.get_replace_hook()
            )
        with replace_ctx:
            submodule, call_module_node = create_submodule_from_nodes(
                tagged_graph_module, node_list, tag
            )

        outer_output_node = [
            candidate
            for candidate in tagged_graph_module.graph.nodes
            if candidate.op == "output"
        ][0]
        submodule_output_node = [
            candidate for candidate in submodule.graph.nodes if candidate.op == "output"
        ][0]
        # create_submodule_from_nodes doesn't cover the meta field, so the
        # submodule's output node takes the original output node's meta
        # (the same dict object is assigned, not a copy).
        submodule_output_node.meta = outer_output_node.meta
        logging.debug(f"Partitioned graph module: {tagged_graph_module}")

        (
            submodule_program,
            input_specs_to_delete,
            output_specs_to_delete,
        ) = create_exported_program_from_submodule(
            submodule,
            owning_program,
            tag,
            call_module_node,
            is_submodule,
        )

        lowered_submodule = to_backend(
            delegation_spec.backend_id,
            submodule_program,
            delegation_spec.compile_specs,
        )

        # call delegate args should only use user_inputs, in user_inputs order;
        # for each name the first input node with that name is taken.
        call_delegate_args = []
        for inp_name in submodule_program.graph_signature.user_inputs:
            matching_input = next(
                (
                    inp_node
                    for inp_node in call_module_node.all_input_nodes
                    if inp_node.name == inp_name
                ),
                None,
            )
            if matching_input is not None:
                call_delegate_args.append(matching_input)

        # Replace the partitioned submodule with a lowered submodule
        # Add call_method node with function "forward"
        with tagged_graph_module.graph.inserting_before(call_module_node):
            lowered_name = get_lowered_module_name(
                tagged_graph_module, lowered_submodule
            )
            lowered_node = tagged_graph_module.graph.get_attr(lowered_name)
            call_delegate_node = tagged_graph_module.graph.call_function(
                executorch_call_delegate,
                (lowered_node,) + tuple(call_delegate_args),
                call_module_node.kwargs,
            )
            call_delegate_node.meta["debug_handle"] = _next_debug_handle(
                owning_program
            )
            call_delegate_node.meta["val"] = submodule_output_node.meta["val"]
            call_module_node.replace_all_uses_with(call_delegate_node)
            tagged_graph_module.graph.erase_node(call_module_node)

        if is_submodule:
            assert len(input_specs_to_delete) == 0
            assert len(output_specs_to_delete) == 0
        elif len(input_specs_to_delete) > 0 or len(output_specs_to_delete) > 0:
            _unsafe_adjust_original_program(
                owning_program,
                call_delegate_node,
                input_specs_to_delete,
                output_specs_to_delete,
            )

    return tagged_graph_module
307*523fa7a6SAndroid Build Coastguard Worker
308*523fa7a6SAndroid Build Coastguard Worker
def _partition_and_lower(
    tagged_graph_module: torch.fx.GraphModule,
    partition_result: PartitionResult,
    owning_program: ExportedProgram,
    is_submodule: bool = False,
) -> torch.fx.GraphModule:
    """
    Partition and lower a graph module and, recursively, its control-flow submodules.

    The graph module itself is handled first by
    ``_partition_and_lower_one_graph_module``. Each control-flow submodule
    reported by ``get_control_flow_submodules`` is then processed by a
    recursive call with ``is_submodule=True`` and re-attached to
    ``tagged_graph_module`` under its original name.

    Args:
        tagged_graph_module: The tagged graph module to process.
        partition_result: Supplies the tag -> delegation spec mapping.
        owning_program: The exported program the graph module belongs to.
        is_submodule: Whether ``tagged_graph_module`` is itself a control-flow
            submodule; False for the top-level call.

    Returns:
        ``tagged_graph_module``, after lowering.
    """
    lowered_root = _partition_and_lower_one_graph_module(
        tagged_graph_module, partition_result, owning_program, is_submodule
    )

    # Recursively partition and lower for submodules
    for name, control_flow_module, _ in get_control_flow_submodules(lowered_root):
        lowered_child = _partition_and_lower(
            control_flow_module, partition_result, owning_program, is_submodule=True
        )
        tagged_graph_module.add_module(name, lowered_child)

    return tagged_graph_module
331*523fa7a6SAndroid Build Coastguard Worker
332*523fa7a6SAndroid Build Coastguard Worker
@to_backend.register
def _(
    edge_program: ExportedProgram,
    partitioner_instance: Partitioner,
) -> ExportedProgram:
    """
    ``to_backend`` overload that delegates the portions of a program chosen by a partitioner:

    ::

     def to_backend(
         edge_program: ExportedProgram,
         partitioner_instance: Partitioner,
     ) -> ExportedProgram:

    The partitioner is called on a fake copy of the program (or a deep copy if
    the fake copy cannot be built) and tags the nodes to delegate. Nodes
    sharing a tag are fused into one subgraph and lowered to the backend named
    in that tag's delegation spec.

    Args:
        edge_program: Program in Edge dialect. It is validated first.
        partitioner_instance: A partitioner instance. Calling it must return a
            result exposing ``tagged_exported_program`` and a non-None
            ``partition_tags`` mapping of tag name to delegation spec.

    Returns:
        ExportedProgram: A new program built from the lowered graph module and
        the tagged program's signature, state dict, constants and verifier,
        with deep copies of its range constraints and module call graph and
        with ``example_inputs`` set to None.

    Raises:
        AssertionError: ``partition_tags`` is None, or validation is enabled
            and the partitioner's graph is not identical to the original.
    """
    edge_program._validate()

    # Use fake program, with FakeTensors in the state dict, to avoid copying large constant values.
    # Fall back to deepcopy if no fake mode is found. TODO(T182910699): Remove this fallback.
    try:
        fake_edge_program = get_fake_program(edge_program)
    except Exception as e:
        logging.warning(
            f"Error in get_fake_program for graph {edge_program.graph_module}, fallback to deepcopy: {e}"
        )
        fake_edge_program = copy.deepcopy(edge_program)

    partitioner_result = partitioner_instance(fake_edge_program)
    tagged_exported_program = partitioner_result.tagged_exported_program

    # Check that the partitioner did not modify the original graph
    if not _ENABLE_VALIDATION:
        logging.warning("Disabled validating the partitioner.")
    else:
        assert is_identical_graph(
            tagged_exported_program.graph_module,
            edge_program.graph_module,
        ), f"The partitioner {partitioner_instance} should not modify the graph module"

    assert (
        partitioner_result.partition_tags is not None
    ), f"Partitioner {partitioner_instance} needs a `partition_tags` field containing a mapping of tags to delegate spec"

    update_to_real_program(tagged_exported_program, edge_program)

    # Only the tag names are needed here; the delegation specs are consumed
    # later, during lowering.
    for tag in partitioner_result.partition_tags:
        _maybe_duplicate_constant_nodes(tagged_exported_program, tag)

    lowered_graph_module = _partition_and_lower(
        tagged_exported_program.graph_module,
        partitioner_result,
        tagged_exported_program,
    )

    return ExportedProgram(
        root=lowered_graph_module,
        graph=lowered_graph_module.graph,
        graph_signature=tagged_exported_program.graph_signature,
        state_dict=tagged_exported_program.state_dict,
        range_constraints=copy.deepcopy(tagged_exported_program.range_constraints),
        module_call_graph=copy.deepcopy(tagged_exported_program.module_call_graph),
        example_inputs=None,
        constants=tagged_exported_program.constants,
        verifiers=[tagged_exported_program.verifier],
    )
413