# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

import copy
import operator
from collections import defaultdict
from typing import Any, Dict, List, Optional, Set, Tuple, Union

import torch
import torch.utils._pytree as pytree
from executorch.exir._serialize import _serialize_pte_binary
from executorch.exir.backend.compile_spec_schema import CompileSpec
from executorch.exir.delegate import executorch_call_delegate, get_lowered_module_name
from executorch.exir.emit import emit_program

from executorch.exir.graph_module import _get_submodule

from executorch.exir.passes.memory_planning_pass import MemoryPlanningPass
from executorch.exir.passes.spec_prop_pass import make_spec, SpecPropPass
from executorch.exir.schema import Program

from executorch.exir.tracer import Value
from torch._library.fake_class_registry import FakeScriptObject

from torch._subclasses import FakeTensor
from torch.export.exported_program import (
    ConstantArgument,
    ExportedProgram,
    ExportGraphSignature,
    InputKind,
    InputSpec,
    ModuleCallEntry,
    ModuleCallSignature,
    OutputKind,
    OutputSpec,
    TensorArgument,
)
from torch.fx.passes.utils.fuser_utils import (
    erase_nodes,
    fuse_as_graphmodule,
    insert_subgm,
    legalize_graph,
    NodeList,
    topo_sort,
)


class LoweredBackendModule(torch.nn.Module):
    """
    A subclass of nn.Module that is generated for modules containing
    delegated functions. This can be created by calling `to_backend`.
    """

    _backend_id: str  # The backend's name
    _processed_bytes: bytes  # The delegate blob created from backend.preprocess
    _compile_specs: List[
        CompileSpec
    ]  # A list of backend-specific objects with static metadata to configure the "compilation" process.
    _original_exported_program: ExportedProgram  # The original EXIR module

    def __init__(
        self,
        edge_program: ExportedProgram,
        backend_id: str,
        processed_bytes: bytes,
        compile_specs: List[CompileSpec],
    ) -> None:
        super().__init__()
        self._original_exported_program = edge_program
        self._backend_id = backend_id
        self._processed_bytes = processed_bytes
        self._compile_specs = compile_specs

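    # A minimal construction sketch (hypothetical names; in practice instances
    # are produced by calling `to_backend` rather than built by hand):
    #
    #     lowered = LoweredBackendModule(
    #         edge_program=edge_prog,        # an ExportedProgram (assumed)
    #         backend_id="ExampleBackend",   # hypothetical backend name
    #         processed_bytes=b"\x00\x01",   # blob from backend.preprocess
    #         compile_specs=[],
    #     )
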
    # pyre-ignore
    def __deepcopy__(self, memo: Optional[Dict[int, Any]]) -> "LoweredBackendModule":
        # Copy exported program
        copied_program = ExportedProgram(
            root=copy.deepcopy(self._original_exported_program.graph_module),
            graph=copy.deepcopy(self._original_exported_program.graph),
            graph_signature=copy.deepcopy(
                self._original_exported_program.graph_signature
            ),
            state_dict=self._original_exported_program.state_dict,
            range_constraints=copy.deepcopy(
                self._original_exported_program.range_constraints
            ),
            module_call_graph=copy.deepcopy(
                self._original_exported_program.module_call_graph
            ),
            constants=self._original_exported_program.constants,
            verifiers=[copy.deepcopy(self._original_exported_program.verifier)],
        )

        res = LoweredBackendModule(
            edge_program=copied_program,
            backend_id=self._backend_id,
            processed_bytes=self._processed_bytes,
            compile_specs=copy.deepcopy(self._compile_specs, memo),
        )
        # pyre-fixme[16]: `LoweredBackendModule` has no attribute `meta`.
        res.meta = copy.copy(getattr(self, "meta", {}))
        return res

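    # Copy-semantics sketch: deep-copying a lowered module clones the graph and
    # signature but shares the (immutable) processed bytes and the state dict
    # with the original, as implemented above:
    #
    #     copied = copy.deepcopy(lowered)
    #     assert copied.processed_bytes is lowered.processed_bytes
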
    @property
    def backend_id(self) -> str:
        """
        Returns the backend's name.
        """
        return self._backend_id

    @property
    def processed_bytes(self) -> bytes:
        """
        Returns the delegate blob created from backend.preprocess.
        """
        return self._processed_bytes

    @property
    def compile_specs(self) -> List[CompileSpec]:
        """
        Returns a list of backend-specific objects with static metadata to configure the "compilation" process.
        """
        return self._compile_specs

    @property
    def original_module(self) -> ExportedProgram:
        """
        Returns the original EXIR module.
        """
        return self._original_exported_program

    # TODO(chenlai): consolidate the serialization config with the serialize_to_flatbuffer API
    def buffer(
        self,
        extract_delegate_segments: bool = False,
        segment_alignment: int = 128,
        constant_tensor_alignment: Optional[int] = None,
        delegate_alignment: Optional[int] = None,
        memory_planning: MemoryPlanningPass = None,  # pyre-fixme[9]
    ) -> bytes:
        """
        Returns a buffer containing the serialized ExecuTorch binary.
        """
        # TODO(T181463742): avoid calling bytes(..) which incurs large copies.
        out = bytes(
            _serialize_pte_binary(
                program=self.program(memory_planning=memory_planning),
                extract_delegate_segments=extract_delegate_segments,
                segment_alignment=segment_alignment,
                constant_tensor_alignment=constant_tensor_alignment,
                delegate_alignment=delegate_alignment,
            )
        )
        return out

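    # Serialization sketch (hypothetical output path; other keywords keep the
    # defaults declared above):
    #
    #     pte_bytes = lowered.buffer(extract_delegate_segments=True)
    #     with open("lowered_module.pte", "wb") as f:
    #         f.write(pte_bytes)
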
    # TODO(chenlai): reconsider recapturing instead of manually constructing the
    # program, because the metadata construction is done manually.
    def program(
        self,
        emit_stacktrace: bool = False,
        memory_planning: MemoryPlanningPass = None,  # pyre-fixme[9]
    ) -> Program:
        """
        Returns the object that represents the ExecuTorch binary before serialization.
        """
        # Fix: autodeps introduces cyclic dependencies:
        # program -> verifier -> lowered_backend_module -> program
        # @manual
        from executorch.exir.program._program import (
            _get_updated_graph_signature,
            _transform,
        )

        # Creates a new module based on the original module. The original module will
        # look something like the following:
        #
        # opcode         name                 target            args                                        kwargs
        # -------------  -------------------  ----------------  ------------------------------------------  --------
        # placeholder    arg0_1               arg0_1            ()                                          {}
        # placeholder    arg1_1               arg1_1            ()                                          {}
        # call_function  aten_repeat_default  *                 (arg1_1, [4, 1])                            {}
        # call_function  aten_mul_tensor      *                 (aten_repeat_default, aten_repeat_default)  {}
        # call_function  aten_add_tensor      *                 (arg1_1, arg1_1)                            {}
        # output         output               output            ([aten_mul_tensor, aten_add_tensor],)       {}
        #
        # If the whole module is lowered, the resulting lowered module looks like:
        #
        # opcode         name                      target                       args                                kwargs
        # -------------  ------------------------  ---------------------------  ----------------------------------  --------
        # placeholder    arg0_1                    arg0_1                       ()                                  {}
        # placeholder    arg1_1                    arg1_1                       ()                                  {}
        # get_attr       lowered_module_0          lowered_module_0             ()                                  {}
        # call_function  executorch_call_delegate  executorch_call_delegate     (lowered_module_0, arg0_1, arg1_1)  {}
        # call_function  getitem                   <built-in function getitem>  (executorch_call_delegate, 0)       {}
        # call_function  getitem_1                 <built-in function getitem>  (executorch_call_delegate, 1)       {}
        # output         output_1                  output                       ([getitem, getitem_1],)             {}
        #
        # We remove all call_function nodes, insert a call_delegate node, insert
        # getitem nodes to extract the results of the call_delegate node, and
        # return the list of getitem nodes as the output.

        lowered_exported_program = copy.deepcopy(self._original_exported_program)

        # The real input nodes are the ones that are neither buffers nor parameters
        all_input_nodes = [
            node
            for node in lowered_exported_program.graph.nodes
            if (
                node.op == "placeholder"
                and node.name
                not in lowered_exported_program.graph_signature.inputs_to_buffers
                and node.name
                not in lowered_exported_program.graph_signature.inputs_to_parameters
            )
        ]

        output_node = [
            node for node in lowered_exported_program.graph.nodes if node.op == "output"
        ]
        assert len(output_node) == 1, "There should be only one output node"

        # Step 1. Clean up the graph before inserting the call_delegate node
        # Remove the original output node
        lowered_exported_program.graph.erase_node(output_node[0])

        # Remove everything else except the inputs
        for node in reversed(lowered_exported_program.graph.nodes):
            if node.op != "placeholder":
                lowered_exported_program.graph.erase_node(node)

        # Find placeholders that are parameters or buffers and remove them from the main graph
        for node in lowered_exported_program.graph.nodes:
            if node.op == "placeholder" and (
                node.name in lowered_exported_program.graph_signature.inputs_to_buffers
                or node.name
                in lowered_exported_program.graph_signature.inputs_to_parameters
            ):
                lowered_exported_program.graph.erase_node(node)

        # Step 2. Start constructing the graph
        lowered_name = get_lowered_module_name(
            lowered_exported_program.graph_module, self
        )
        # Insert the lowered module into the graph module as an attribute
        lowered_node = lowered_exported_program.graph.get_attr(lowered_name)

        # Insert a call_delegate node into the graph module, with arguments from the arg list
        delegate_node = lowered_exported_program.graph.call_function(
            executorch_call_delegate, (lowered_node, *all_input_nodes)
        )
        # Get the output list. Since the output node is a tuple containing a list,
        # like ([aten_mul_tensor, aten_add_tensor],), we add some handling logic
        # to get the list `[aten_mul_tensor, aten_add_tensor]` properly.
        original_output_nodes = [
            node
            for node in self._original_exported_program.graph.nodes
            if node.op == "output"
        ][0].args[0]

        delegate_node.meta["spec"] = tuple(
            [make_spec(node.meta["val"]) for node in original_output_nodes]
        )
        delegate_node.meta["val"] = tuple(
            [node.meta["val"] for node in original_output_nodes]
        )

        # The getitem nodes that are going to be inserted into the lowered graph module
        getitem_nodes = []
        for i in range(len(original_output_nodes)):
            getitem_node = lowered_exported_program.graph.call_function(
                operator.getitem,
                args=(delegate_node, i),
            )
            getitem_node.meta["val"] = delegate_node.meta["val"][i]
            getitem_nodes.append(getitem_node)
        lowered_exported_program.graph.output(getitem_nodes)

        lowered_exported_program.graph_module.recompile()
        lowered_exported_program.graph.lint()

        # The user outputs will be the getitem nodes instead
        output_specs = [
            OutputSpec(
                kind=OutputKind.USER_OUTPUT,
                arg=TensorArgument(name=getitem_node.name),
                target=None,
            )
            for getitem_node in getitem_nodes
        ]
        # All data are consumed by the delegate, so they should be removed from the state dict.
        inputs_to_parameters = (
            lowered_exported_program.graph_signature.inputs_to_parameters
        )
        inputs_to_buffers = lowered_exported_program.graph_signature.inputs_to_buffers
        input_specs = [
            InputSpec(
                kind=InputKind.USER_INPUT,
                arg=TensorArgument(name=user_input),
                target=None,
            )
            for user_input in lowered_exported_program.graph_signature.user_inputs
            if user_input not in inputs_to_parameters
            and user_input not in inputs_to_buffers
        ]

        # Double-check that the ExportedProgram data (especially everything except the graph) is good
        exported_program = ExportedProgram(
            root=lowered_exported_program.graph_module,
            graph=lowered_exported_program.graph,
            graph_signature=_get_updated_graph_signature(
                ExportGraphSignature(
                    input_specs=input_specs, output_specs=output_specs
                ),
                lowered_exported_program.graph_module,
            ),
            # TODO: May need to set lowered_exported_program.call_spec = CallSpec(None, None)
            # somewhere, as we should pass a list of tensors to the lowered module and output a
            # list of tensors. Putting call_spec=lowered_exported_program.call_spec is correct here as the
            # inputs/outputs to the toplevel program will be in the format of the eager module.
            state_dict={},  # Empty because all data are consumed by the delegate
            range_constraints=lowered_exported_program.range_constraints,
            module_call_graph=lowered_exported_program.module_call_graph,
            example_inputs=None,
            verifiers=[lowered_exported_program.verifier],
        )
        if memory_planning is None:
            memory_planning = MemoryPlanningPass()
        exported_program = _transform(exported_program, SpecPropPass(), memory_planning)
        emitted_program = emit_program(
            exported_program, emit_stacktrace=emit_stacktrace
        ).program
        return emitted_program

    # Used to patch each delegated function with a call_delegate call
    # @staticmethod
    def forward(
        self,
        *args: Value,
        **kwargs: Tuple[Value, ...],
    ) -> Value:
        return executorch_call_delegate(self, *args)

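    # Call sketch: invoking the lowered module eagerly dispatches through
    # `executorch_call_delegate`, e.g. (hypothetical single-tensor input):
    #
    #     out = lowered(torch.randn(2, 2))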

# TODO(zhxchen17) Try ExportPass
def _fixup_output_node(gm: torch.fx.GraphModule) -> None:
    for node in reversed(gm.graph.nodes):
        if node.op == "output":
            with gm.graph.inserting_before(node):
                assert len(node.args) == 1
                outputs = node.args[0]
                if isinstance(outputs, torch.fx.Node):
                    val = outputs.meta.get("val")
                    if isinstance(val, list):
                        # If a list is returned, in some cases it is represented as a
                        # single node, like `split_copy_tensor`, but EXIR will return an
                        # opened-up list like `[getitem1, getitem2]`
                        outputs = [
                            torch.fx.Proxy(outputs)[i].node for i in range(len(val))
                        ]
            returns, out_spec = pytree.tree_flatten(outputs)
            node.args = (returns,)
            return

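# Flattening sketch: if the output node's single argument is one fx.Node whose
# `val` is a list (e.g. the result of a `split_copy`), it is opened up into
# `[getitem_0, getitem_1, ...]` before being flattened into the output args.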

def arrange_graph_placeholders(
    gm: torch.fx.GraphModule, owning_program: ExportedProgram
) -> torch.fx.GraphModule:
    """
    Modifies the graph of the given graph module in place so that it contains
    the same nodes as the original, but with placeholders ordered as
    (params + buffers), then (user inputs).

    This is used by the delegate API, which disturbs the placeholder ordering
    when creating a submodule from partitioned nodes.

    Args:
        gm: The graph module that we want arranged
        owning_program: ExportedProgram that the submodule (gm) belongs to

    Returns:
        The graph module, arranged in place
    """
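    # Ordering sketch: given placeholders [x, weight, buf] where `weight` is a
    # parameter and `buf` is a buffer of the owning program, the rearranged
    # graph yields [weight, buf, x] (parameters, then buffers, then user inputs).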
    new_graph = torch.fx.Graph()

    node_map = {}  # mapping of nodes from old graph to new graph

    graph_sign = owning_program.graph_signature

    # Add all placeholders into the graph first:
    param_nodes = []
    buffer_nodes = []
    input_nodes = []
    for node in gm.graph.nodes:
        if node.op != "placeholder":
            continue

        if node.name in graph_sign.inputs_to_parameters:
            param_nodes.append(node)
        elif node.name in graph_sign.inputs_to_buffers:
            buffer_nodes.append(node)
        else:
            input_nodes.append(node)

    for param_node in param_nodes:
        new_node = new_graph.node_copy(param_node, lambda x: node_map[x])
        node_map[param_node] = new_node
    for buffer_node in buffer_nodes:
        new_node = new_graph.node_copy(buffer_node, lambda x: node_map[x])
        node_map[buffer_node] = new_node
    for input_node in input_nodes:
        new_node = new_graph.node_copy(input_node, lambda x: node_map[x])
        node_map[input_node] = new_node

    # Now add all the other nodes in order
    for node in gm.graph.nodes:
        if node.op == "placeholder":
            continue

        new_node = new_graph.node_copy(node, lambda x: node_map[x])
        node_map[node] = new_node

    # lint to ensure correctness
    new_graph.lint()

    new_graph._codegen = gm.graph._codegen
    gm.graph = new_graph

    return gm


# TODO Don't regenerate new signature manually.
def _get_new_signature(  # noqa: C901
    original_program: ExportedProgram,
    gm: torch.fx.GraphModule,
    call_module_node: torch.fx.Node,
    tag: str,
    is_submodule: bool = False,
) -> Tuple[
    ExportGraphSignature,
    Dict[str, Union[torch.Tensor, torch.nn.Parameter]],
    Dict[str, Union[torch.Tensor, FakeScriptObject, torch.ScriptObject]],
    Dict[str, InputSpec],
    Dict[str, OutputSpec],
]:
    """
    Args:
        original_program: The original program that we are partitioning
        gm: The partitioned graph module.
        call_module_node: The node in the original program that is calling the
            partitioned graph module.
        tag: The tag being used for this partitioned submodule. This is used to
            tell if a particular parameter/buffer/constant node is being tagged,
            aka consumed by the delegate.
        is_submodule: True if we are currently partitioning inside of a
            submodule (like cond's submodule). If we are inside of a submodule,
            we do not care about consuming params/buffers.

    Returns:
        new_signature (ExportGraphSignature): The new signature for the
            partitioned graph module.
        new_state_dict (Dict[str, Union[torch.Tensor, torch.nn.Parameter]]): The
            new state dict containing the consumed params/buffers.
        new_constants (Dict[str, Union[torch.Tensor, FakeScriptObject,
            torch.ScriptObject]]): The new constants table containing the
            consumed constants.
        input_specs_to_delete (Dict[str, InputSpec]): The input specs that have
            been consumed by the delegate (param/buffer input nodes) and should
            be removed from the toplevel ExportedProgram.
        output_specs_to_delete (Dict[str, OutputSpec]): The output specs that have
            been consumed by the delegate (buffer mutation nodes) and should be
            removed from the toplevel ExportedProgram.
    """
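    # Call sketch (illustrative names): given a partitioned `sub_gm` and the
    # call_module node that invokes it in the owning program:
    #
    #     sig, state_dict, constants, in_del, out_del = _get_new_signature(
    #         owning_program, sub_gm, call_module_node, tag="tag0"
    #     )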
    old_signature = original_program.graph_signature

    input_specs = []
    output_specs = []
    input_specs_to_delete = {}
    output_specs_to_delete = {}
    new_state_dict = {}
    new_constants = {}

    # If we are within a submodule, we do not need to care about consuming
    # parameters/buffers
    input_node_to_sig: Dict[str, InputSpec] = (
        {input_spec.arg.name: input_spec for input_spec in old_signature.input_specs}
        if not is_submodule
        else {}
    )

    toplevel_output_node_to_sig: Dict[str, List[OutputSpec]] = defaultdict(list)
    if not is_submodule:
        for output_spec in old_signature.output_specs:
            toplevel_output_node_to_sig[output_spec.arg.name].append(output_spec)

    for node in gm.graph.nodes:
        if node.op == "placeholder":
            if node.name not in input_node_to_sig:
                input_specs.append(
                    InputSpec(
                        kind=InputKind.USER_INPUT,
                        arg=TensorArgument(name=node.name),
                        target=None,
                    )
                )
                continue

            orig_input_spec = input_node_to_sig[node.name]

            if not isinstance(orig_input_spec.arg, TensorArgument):
                input_specs.append(orig_input_spec)

            elif node.meta.get("delegation_tag", None) == tag:
                input_specs.append(orig_input_spec)

                if orig_input_spec.kind == InputKind.USER_INPUT:
                    continue

                # The following input specs are all attributes that should be
                # consumed by the delegate, so we want to remove them from the
                # toplevel module inputs/outputs
                input_specs_to_delete[node.name] = orig_input_spec

                input_target = orig_input_spec.target
                if input_target in original_program.state_dict:
                    assert orig_input_spec.kind in (
                        InputKind.PARAMETER,
                        InputKind.BUFFER,
                    )

                    new_state_dict[input_target] = original_program.state_dict[
                        input_target
                    ]
                elif input_target in original_program.constants:
                    assert orig_input_spec.kind in (
                        InputKind.CONSTANT_TENSOR,
                        InputKind.CUSTOM_OBJ,
                        InputKind.BUFFER,
                    )

                    new_constants[input_target] = original_program.constants[
                        input_target
                    ]
                else:
                    raise RuntimeError(f"Invalid input spec {orig_input_spec} received")

            else:
                input_specs.append(
                    InputSpec(
                        kind=InputKind.USER_INPUT,
                        arg=TensorArgument(name=node.name),
                        target=None,
                    )
                )

        if node.op == "output":
            buffer_mutation_idxs: Dict[int, List[OutputSpec]] = defaultdict(list)
            for user in call_module_node.users.keys():
                if user.name in toplevel_output_node_to_sig:
                    assert (
                        user.op == "call_function" and user.target == operator.getitem
                    ), f"Invalid user {user}, node.op is {user.op} and node.target is {user.target}"
                    getitem_idx = user.args[1]
                    assert isinstance(
                        getitem_idx, int
                    ), f"Invalid getitem type: {type(getitem_idx)}"
                    buffer_mutation_idxs[getitem_idx].extend(
                        toplevel_output_node_to_sig[user.name]
                    )

            for i, output_node in enumerate(node.args[0]):
                if i in buffer_mutation_idxs:
                    assert isinstance(output_node, torch.fx.Node)
                    orig_output_specs = buffer_mutation_idxs[i]

                    if any(
                        orig_output_spec.kind == OutputKind.BUFFER_MUTATION
                        and orig_output_spec.target in new_state_dict
                        for orig_output_spec in orig_output_specs
                    ):
                        # If the delegate wants to consume the buffer, then the
                        # delegate should also consume the buffer mutation
                        # (the output spec would be a BUFFER_MUTATION). Otherwise
                        # the delegate will just return the result of the
                        # mutation as a USER_OUTPUT.

                        orig_output_spec = [
                            orig_output_spec
                            for orig_output_spec in orig_output_specs
                            if orig_output_spec.kind == OutputKind.BUFFER_MUTATION
                            and orig_output_spec.target in new_state_dict
                        ][0]

                        assert len(orig_output_specs) == 1, (
                            f"Constant {orig_output_spec.target} was tagged to be "
                            "consumed by the delegate, and was found to also contain "
                            "a buffer mutation. However this buffer mutation node "
                            "was found to also be used as other types of outputs "
                            "which is currently not supported. Please file an "
                            "issue on Github. \n\n"
                            f"The toplevel program: {original_program}\n"
                        )
                        output_specs.append(
                            OutputSpec(
                                kind=OutputKind.BUFFER_MUTATION,
                                arg=TensorArgument(name=output_node.name),
                                target=orig_output_spec.target,
                            )
                        )
                        output_specs_to_delete[orig_output_spec.arg.name] = (
                            orig_output_spec
                        )
                    else:
                        output_specs.append(
                            OutputSpec(
                                kind=OutputKind.USER_OUTPUT,
                                arg=TensorArgument(name=output_node.name),
                                target=None,
                            )
                        )

                elif not isinstance(output_node, torch.fx.Node):
                    output_specs.append(
                        OutputSpec(
                            kind=OutputKind.USER_OUTPUT,
                            arg=ConstantArgument(name="", value=output_node),
                            target=None,
                        )
                    )

                else:
                    output_specs.append(
                        OutputSpec(
                            kind=OutputKind.USER_OUTPUT,
                            arg=TensorArgument(name=output_node.name),
                            target=None,
                        )
                    )

    new_signature = ExportGraphSignature(
        input_specs=input_specs, output_specs=output_specs
    )

    return (
        new_signature,
        new_state_dict,
        new_constants,
        input_specs_to_delete,
        output_specs_to_delete,
    )


def create_exported_program_from_submodule(
    submodule: torch.fx.GraphModule,
    owning_program: ExportedProgram,
    tag: str,
    call_module_node: torch.fx.Node,
    is_submodule: bool,
) -> Tuple[ExportedProgram, Dict[str, InputSpec], Dict[str, OutputSpec]]:
    """
    Creates an ExportedProgram from the given submodule using the parameters and buffers
    from the top-level owning program.

    Args:
        submodule: submodule to create an ExportedProgram from
        owning_program: exported program containing the parameters and buffers used within
            the submodule
        tag: the delegation tag used for this partitioned submodule
        call_module_node: the node in the owning program that calls the submodule
        is_submodule: True if we are partitioning inside of an existing submodule

    Returns:
        The ExportedProgram created from submodule
        input_specs_to_delete (Dict[str, InputSpec]): The input specs that have
            been consumed by the delegate (param/buffer input nodes) and should
            be removed from the toplevel ExportedProgram.
        output_specs_to_delete (Dict[str, OutputSpec]): The output specs that have
            been consumed by the delegate (buffer mutation nodes) and should be
            removed from the toplevel ExportedProgram.
    """
    # Arrange the submodule's placeholders in order
    submodule = arrange_graph_placeholders(submodule, owning_program)

    # TODO: we probably need to arrange the outputs wrt buffer mutations.

    # Get the updated graph signature
    (
        subgraph_signature,
        subgraph_state_dict,
        subgraph_constants,
        toplevel_input_specs_to_delete,
        toplevel_output_specs_to_delete,
    ) = _get_new_signature(
        owning_program, submodule, call_module_node, tag, is_submodule
    )

    in_spec = pytree.tree_flatten((tuple(subgraph_signature.user_inputs), {}))[1]
    out_spec = pytree.tree_flatten(subgraph_signature.user_outputs)[1]

    return (
        ExportedProgram(
            root=submodule,
            graph=submodule.graph,
            graph_signature=subgraph_signature,
            state_dict=subgraph_state_dict,
            range_constraints=copy.deepcopy(owning_program.range_constraints),
            module_call_graph=[
                ModuleCallEntry(
                    "",
                    ModuleCallSignature(
                        inputs=[], outputs=[], in_spec=in_spec, out_spec=out_spec
                    ),
                )
            ],
            constants=subgraph_constants,
            verifiers=[owning_program.verifier],
        ),
        toplevel_input_specs_to_delete,
        toplevel_output_specs_to_delete,
    )


def create_submodule_from_nodes(
    gm: torch.fx.GraphModule,
    node_list: NodeList,
    tag: str,
    skip_legalize_graph: bool = False,
) -> Tuple[torch.fx.GraphModule, torch.fx.Node]:
    """
    Modifies the given graph module in place to separate out the given nodes
    into a submodule. The given node_list should form a fully connected
    subgraph.

    Args:
        gm: The graph module that we want to partition
        node_list: A list of nodes that belong in the partition
        tag: The tag used to name the fused submodule
        skip_legalize_graph: If True, skips the topological re-sort of the
            modified graph

    Returns:
        The submodule that has been partitioned, and the call_module node in
        the toplevel graph module calling the submodule
    """
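    # Usage sketch (hypothetical tag; `tagged_nodes` would come from a
    # partitioner):
    #
    #     sub_gm, call_module_node = create_submodule_from_nodes(
    #         gm, node_list=tagged_nodes, tag="tag0"
    #     )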
    sorted_nodes = topo_sort(node_list)

    submodule_name = "fused_" + tag
    sub_gm, orig_inputs, orig_outputs = fuse_as_graphmodule(
        gm, sorted_nodes, submodule_name
    )

    _fixup_output_node(sub_gm)

    gm = insert_subgm(gm, sub_gm, orig_inputs, orig_outputs)
    submodule_node = None
    for node in gm.graph.nodes:
        if node.op == "call_module":
            if node.target == submodule_name:
                submodule_node = node
            else:
                raise RuntimeError(
                    f"The submodule created with nodes {node_list} did not form "
                    "one fully contained subgraph. Check that these nodes form a "
                    f"fully contained graph. Partitioned graph: {gm.graph}."
                )

    if len(orig_outputs) == 1 and isinstance(orig_outputs[0].meta["val"], FakeTensor):
        # If the original output is a single tensor, it has been
        # pytree.tree_flatten-ed to be a singleton list, so we want to replace
        # all uses with a getitem call to the 0th index of the result
        with gm.graph.inserting_after(submodule_node):
            proxy_out = torch.fx.Proxy(submodule_node)[0].node  # type: ignore[index]
            submodule_node.replace_all_uses_with(proxy_out)
            proxy_out.meta["val"] = submodule_node.meta["val"]
            # Reset the args since they were overwritten in the previous line
            proxy_out.args = (submodule_node, 0)
    else:
        # fuse_as_graphmodule will automatically propagate the metadata of the
        # partition's last node to the getitem nodes that appear after the
        # call_module node. However, in the case of delegation we do not want
        # these getitem nodes to contain irrelevant previous metadata
        # (e.g. source_fn_stack, nn_module_stack)
        for user_node in submodule_node.users:
            user_node.meta.pop("nn_module_stack", None)
            user_node.meta.pop("source_fn_stack", None)

    erase_nodes(gm, sorted_nodes)

    # Topologically sort the original gm with the newly created sub_gm
    # TODO: T153794167 Get rid of support for skipping legalize graph in create_submodule_from_nodes
    # once we transition to using fuse_by_partitions.
    if not skip_legalize_graph:
        legalize_graph(gm)

    # Get the call_module node
    submodule_node = None
    for node in gm.graph.nodes:
        if node.op == "call_module" and node.target == submodule_name:
            submodule_node = node
        elif node.op == "call_module":
            raise RuntimeError(
                f"The submodule created with nodes {node_list} did not form "
                "one fully contained subgraph. Check that these nodes form a "
                f"fully contained graph. Partitioned graph: {gm.graph}."
            )

    assert (
        submodule_node is not None
    ), f"No submodule was created with the nodes {node_list} in the graph {gm.graph}"

    return sub_gm, submodule_node


def get_lowered_submodules(
    graph_module: torch.fx.GraphModule,
) -> List[Tuple[str, LoweredBackendModule, torch.fx.Node]]:
    """
    Returns a list of lowered modules that are in the given graph (does not look
    into submodules). Specifically, the returned value is a list containing
    tuples of (name of the lowered module stored in the graph module, the
    lowered module itself, and the fx node that called this lowered module).
    """
    lowered_submodules = []
    for node in graph_module.graph.nodes:
        if node.op == "call_function" and node.target == executorch_call_delegate:
            name, module, node = _get_submodule(graph_module, node, 0)
            assert isinstance(module, LoweredBackendModule)
            lowered_submodules.append((name, module, node))
    return lowered_submodules

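# Usage sketch:
#
#     for name, lowered_module, call_node in get_lowered_submodules(gm):
#         print(name, lowered_module.backend_id)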

def get_lowered_backend_modules(
    graph_module: torch.fx.GraphModule,
) -> List[LoweredBackendModule]:
    """
    Returns a list of the LoweredBackendModules in the given graph, i.e. the
    modules that were lowered by backend delegates.
    """
    lowered_programs = []
    for node in graph_module.graph.nodes:
        if node.op == "call_function" and node.target == executorch_call_delegate:
            lowered_backend_module = getattr(graph_module, node.args[0].name)
            lowered_programs.append(lowered_backend_module)

    return lowered_programs


def _unsafe_adjust_original_program(  # noqa: C901
    original_program: ExportedProgram,
    call_delegate_node: torch.fx.Node,
    input_specs_to_delete: Dict[str, InputSpec],
    output_specs_to_delete: Dict[str, OutputSpec],
) -> None:
    """
    Directly modifies the original exported program's signature and state dict
    based on the params/buffers that were consumed by the delegate.
    """
    original_program._graph_signature.input_specs = [
        input_spec
        for input_spec in original_program.graph_signature.input_specs
        if input_spec.arg.name not in input_specs_to_delete
    ]

    currently_used_targets: Set[str] = {
        input_spec.target
        for input_spec in original_program._graph_signature.input_specs
        if input_spec.target is not None
    }

    original_program._graph_signature.output_specs = [
        output_spec
        for output_spec in original_program.graph_signature.output_specs
        if output_spec.arg.name not in output_specs_to_delete
    ]

    # Delete all parameters/buffers consumed by the created exported program
    # from the graph signature, state dict, and constants table
    for node in original_program.graph.nodes:
        if node.op == "placeholder":
            if node.name in input_specs_to_delete:
                assert len(node.users) == 0
                original_program.graph.erase_node(node)
        else:
            break

    for input_spec in input_specs_to_delete.values():
        input_target = input_spec.target
        assert input_target is not None

        if input_target in currently_used_targets:
            continue

        if input_spec.kind == InputKind.PARAMETER:
            del original_program._state_dict[input_target]
        elif input_spec.kind == InputKind.BUFFER:
            if input_spec.persistent:
                del original_program._state_dict[input_target]
            else:
                del original_program._constants[input_spec.target]
        elif input_spec.kind == InputKind.CONSTANT_TENSOR:
            del original_program._constants[input_spec.target]
        else:
            raise RuntimeError(f"Invalid input spec {input_spec} received")

    # Delete buffer mutations from the output which were consumed by the delegate
    toplevel_output_node = None
    for node in reversed(original_program.graph.nodes):
        if node.op == "output":
            toplevel_output_node = node
            break

    assert toplevel_output_node is not None
    assert (
        len(toplevel_output_node.args) == 1
    ), f"Invalid output node: {toplevel_output_node} with args {toplevel_output_node.args}"

    new_output_args = [
        arg
        for arg in toplevel_output_node.args[0]
        if not isinstance(arg, torch.fx.Node) or arg.name not in output_specs_to_delete
    ]
    toplevel_output_node.args = (tuple(new_output_args),)

    # Delete the buffer mutation getitem nodes
    getitem_idxs: List[int] = []
    user_nodes = list(call_delegate_node.users.keys())
    for user in user_nodes:
        if user.name in output_specs_to_delete:
            assert (
                user.op == "call_function" and user.target == operator.getitem
            ), f"Invalid user {user}, node.op is {user.op} and node.target is {user.target}"
            user_idx = user.args[1]
            assert isinstance(user_idx, int), f"Invalid getitem type: {type(user_idx)}"
            getitem_idxs.append(user_idx)
            original_program.graph.erase_node(user)

    getitem_idxs.sort(reverse=True)

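    # Worked example (assumed indices): if getitems at indices [5, 2] were
    # deleted, a surviving getitem at index 3 has one deleted index below it
    # and becomes 2, while one at index 6 has two below it and becomes 4.
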
    # Adjust all the getitem indices that come after the deleted getitems
    user_nodes = list(call_delegate_node.users.keys())
    for user in user_nodes:
        assert user.op == "call_function" and user.target == operator.getitem
        user_idx = user.args[1]
        assert isinstance(user_idx, int)
        for i, idx in enumerate(getitem_idxs):
            if user_idx > idx:
                user.args = (user.args[0], user_idx - (len(getitem_idxs) - i))
                break

    original_program._validate()