xref: /aosp_15_r20/external/executorch/exir/backend/utils.py (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1*523fa7a6SAndroid Build Coastguard Worker# Copyright (c) Meta Platforms, Inc. and affiliates.
2*523fa7a6SAndroid Build Coastguard Worker# All rights reserved.
3*523fa7a6SAndroid Build Coastguard Worker#
4*523fa7a6SAndroid Build Coastguard Worker# This source code is licensed under the BSD-style license found in the
5*523fa7a6SAndroid Build Coastguard Worker# LICENSE file in the root directory of this source tree.
6*523fa7a6SAndroid Build Coastguard Worker
7*523fa7a6SAndroid Build Coastguard Workerimport logging
8*523fa7a6SAndroid Build Coastguard Workerimport operator
9*523fa7a6SAndroid Build Coastguard Workerfrom collections import defaultdict
10*523fa7a6SAndroid Build Coastguard Workerfrom functools import lru_cache
11*523fa7a6SAndroid Build Coastguard Workerfrom typing import Dict, Iterable, List, Optional, Set, Tuple, Union
12*523fa7a6SAndroid Build Coastguard Worker
13*523fa7a6SAndroid Build Coastguard Workerimport torch
14*523fa7a6SAndroid Build Coastguard Workerfrom executorch.exir.backend.backend_details import ExportedProgram
15*523fa7a6SAndroid Build Coastguard Workerfrom executorch.exir.backend.canonical_partitioners.duplicate_constant_node_pass import (
16*523fa7a6SAndroid Build Coastguard Worker    duplicate_constant_node,
17*523fa7a6SAndroid Build Coastguard Worker)
18*523fa7a6SAndroid Build Coastguard Workerfrom executorch.exir.common import setting_python_recursive_limit
19*523fa7a6SAndroid Build Coastguard Workerfrom executorch.exir.delegate import executorch_call_delegate
20*523fa7a6SAndroid Build Coastguard Workerfrom executorch.exir.dialects._ops import ops as exir_ops
21*523fa7a6SAndroid Build Coastguard Worker
22*523fa7a6SAndroid Build Coastguard Workerfrom executorch.exir.lowered_backend_module import create_submodule_from_nodes
23*523fa7a6SAndroid Build Coastguard Workerfrom torch._export.utils import is_buffer, is_lifted_tensor_constant, is_param
24*523fa7a6SAndroid Build Coastguard Workerfrom torch.fx.node import Node
25*523fa7a6SAndroid Build Coastguard Workerfrom torch.fx.passes.utils.source_matcher_utils import SourcePartition
26*523fa7a6SAndroid Build Coastguard Worker
# Edge-dialect overloads of the per-tensor quantize / dequantize ops. The
# quant-pattern helpers in this module compare node targets against these.
T_QuantPerTensor, T_DQuantPerTensor = (
    exir_ops.edge.quantized_decomposed.quantize_per_tensor.default,
    exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default,
)
29*523fa7a6SAndroid Build Coastguard Worker
30*523fa7a6SAndroid Build Coastguard Worker
# NB: Set this to None to handle validation from MobileBert.
# The unbounded cache presumably avoids re-comparing shared sub-DAGs in deep
# graphs; note that it keeps a strong reference to every argument it has seen.
@lru_cache(maxsize=None)
def is_same_node(
    node_left: Iterable[torch.fx.Node],
    node_right: Iterable[torch.fx.Node],
) -> bool:
    """
    Structurally compare two FX nodes, or two iterables of FX nodes.

    Two nodes are the same when they have the same ``op`` and ``target`` and
    their ``all_input_nodes`` are pairwise the same (recursively). Only node
    inputs are compared; non-node arguments (e.g. Python scalars) are not
    inspected.

    Two iterables are the same when they have equal length and their elements
    are pairwise the same. A node is never the same as a non-node.

    Because results are memoized with ``lru_cache``, arguments must be
    hashable: pass tuples rather than lists.
    """
    left_is_node = isinstance(node_left, torch.fx.Node)
    right_is_node = isinstance(node_right, torch.fx.Node)
    if left_is_node != right_is_node:
        # Mixed node / iterable comparison; iterating a Node would raise.
        return False

    if left_is_node:
        return bool(
            node_left.target == node_right.target
            and node_left.op == node_right.op
            and len(node_left.all_input_nodes) == len(node_right.all_input_nodes)
            and all(
                is_same_node(arg_left, arg_right)
                for arg_left, arg_right in zip(
                    node_left.all_input_nodes, node_right.all_input_nodes
                )
            )
        )

    # Materialize each iterable exactly once so one-shot iterators are not
    # exhausted by the length check before their elements are compared.
    left_items = tuple(node_left)
    right_items = tuple(node_right)
    if len(left_items) != len(right_items):
        return False
    return all(
        is_same_node(n_left, n_right)
        for n_left, n_right in zip(left_items, right_items)
    )
59*523fa7a6SAndroid Build Coastguard Worker
60*523fa7a6SAndroid Build Coastguard Worker
def is_identical_graph(
    graph_left: torch.fx.GraphModule, graph_right: torch.fx.GraphModule
) -> bool:
    """
    Return True when two graph modules hold structurally identical graphs.

    The comparison is order-sensitive: both graphs must have the same number
    of nodes, and the i-th node of one must match the i-th node of the other
    according to ``is_same_node``. Two graphs whose nodes are listed in a
    different order are reported as different, even if both orders are valid
    topological orders of the same computation.
    """
    left_nodes = list(graph_left.graph.nodes)
    right_nodes = list(graph_right.graph.nodes)
    if len(left_nodes) != len(right_nodes):
        return False
    # is_same_node recurses through node inputs, so deep graphs presumably
    # need more headroom than the default recursion limit provides.
    with setting_python_recursive_limit(30000):
        return all(
            is_same_node(left, right)
            for left, right in zip(left_nodes, right_nodes)
        )
77*523fa7a6SAndroid Build Coastguard Worker
78*523fa7a6SAndroid Build Coastguard Worker
def remove_first_quant_and_last_dequant(
    graph_module: torch.fx.GraphModule,
) -> None:
    """
    Remove the quantize node applied directly to a graph input and the
    dequantize node(s) feeding the graph output.

    - A ``quantize_per_tensor`` node whose first argument is a placeholder has
      every use redirected to that placeholder.
    - A ``dequantize_per_tensor`` node used by the output node is replaced,
      inside the output node only, by its own first argument. Other graph
      outputs are left as they are.

    The now-unused quant/dequant nodes are dropped by dead-code elimination,
    and the module is recompiled. The graph is modified in place.
    """
    for node in graph_module.graph.nodes:
        if node.target == T_QuantPerTensor:
            source = node.args[0]
            if isinstance(source, torch.fx.Node) and source.op == "placeholder":
                # Point every consumer (normally the matching dequant) at the
                # placeholder, wherever the quant node appears in its args.
                node.replace_all_uses_with(source)
        elif node.target == T_DQuantPerTensor:
            source = node.args[0]
            # Copy: replace_input_with mutates node.users while we iterate.
            for user in list(node.users):
                if user.op == "output":
                    # Swap only this dequant inside the output's args; assigning
                    # a fresh args tuple would discard the other outputs.
                    user.replace_input_with(node, source)
    # Remove the quant/dequant nodes as they don't have users
    graph_module.graph.eliminate_dead_code()
    graph_module.recompile()
98*523fa7a6SAndroid Build Coastguard Worker
99*523fa7a6SAndroid Build Coastguard Worker
# TODO - use edge ops
def replace_quantized_partition_with_op(
    graph_module: torch.fx.GraphModule,
    partition: SourcePartition,
    replacement_op: torch._ops.OpOverloadPacket,
) -> Tuple[torch.fx.Node, List[torch.fx.Node], List[torch.fx.Node]]:
    """
    Replaces partition with the op specified by replacement_op. It's also expected that
    the nodes contained in partition are sourced from a quantized module as this function
    searches for the quantization pattern to consume along with the nodes in the partition,
    to be then replaced by replacement_op.

    Args:
        graph_module: The graph module from which this partition was sourced. It is
            modified in place and recompiled.
        partition: Partition to be replaced.
        replacement_op: The op to replace the partition with.
    Returns:
        Tuple: The first element is the newly inserted ``call_function`` node that
        invokes replacement_op. The second and third elements are the dq and q nodes
        (each listed once, in a deterministic order) that were consumed along with
        this partition and replaced by replacement_op.
    Raises:
        AssertionError: if no dq node feeds the partition, or no q node consumes it.
    """
    partition_node_set = set(partition.nodes)
    partition_nodes = [node for node in partition.nodes if node not in partition.params]

    # We recreate our input and output node lists instead of using
    # partition.input_nodes and partition.output_nodes, because the ordering of the
    # nodes in those lists is not deterministic, whereas the quant fusion pass expects
    # deterministic ordering. Insertion-ordered dicts act as ordered sets, so a node
    # reached several times (shared input, several external users) is kept once.
    input_nodes: Dict[torch.fx.Node, None] = {}
    output_nodes: Dict[torch.fx.Node, None] = {}
    for node in partition.nodes:
        for arg in node.args:
            if isinstance(arg, torch.fx.Node) and arg not in partition_node_set:
                input_nodes[arg] = None
        if any(user not in partition_node_set for user in node.users):
            output_nodes[node] = None

    # Try to find all the dq nodes that are feeding into this module.
    dequant_nodes = [node for node in input_nodes if node.target == T_DQuantPerTensor]

    # Try to find all the q nodes that this module is feeding out into.
    quant_nodes = list(
        dict.fromkeys(
            user
            for node in output_nodes
            for user in node.users
            if user.target == T_QuantPerTensor
        )
    )

    assert len(dequant_nodes) >= 1, "Dequant nodes missing in node list to be replaced."
    assert len(quant_nodes) >= 1, "Quant nodes missing in node list to be replaced."

    # After this, node list will essentially contain all the nodes in the
    # dq->op->q pattern that we will want to replace with a custom backend op.
    node_list = dequant_nodes + partition_nodes + quant_nodes

    _, call_module_node = create_submodule_from_nodes(
        graph_module, node_list, "to_be_replaced", skip_legalize_graph=True
    )

    # Swap the call_module node for a direct call to replacement_op, reusing its
    # latest args/kwargs and carrying its metadata over.
    with graph_module.graph.inserting_before(call_module_node):
        replaced_op = graph_module.graph.call_function(
            replacement_op,
            call_module_node.args,
            kwargs=call_module_node.kwargs,
        )
        call_module_node.replace_all_uses_with(replaced_op)
        graph_module.graph.erase_node(call_module_node)
        replaced_op.meta = call_module_node.meta
    graph_module.recompile()

    return (replaced_op, dequant_nodes, quant_nodes)
176*523fa7a6SAndroid Build Coastguard Worker
177*523fa7a6SAndroid Build Coastguard Worker
178*523fa7a6SAndroid Build Coastguard Workerdef _assign_new_tag(
179*523fa7a6SAndroid Build Coastguard Worker    tagged_exported_program: ExportedProgram,
180*523fa7a6SAndroid Build Coastguard Worker    copied_nodes: Set[str],
181*523fa7a6SAndroid Build Coastguard Worker):
182*523fa7a6SAndroid Build Coastguard Worker    """
183*523fa7a6SAndroid Build Coastguard Worker    Assign new tag to the copied nodes.
184*523fa7a6SAndroid Build Coastguard Worker
185*523fa7a6SAndroid Build Coastguard Worker    Before the pass
186*523fa7a6SAndroid Build Coastguard Worker    constant_0 (tag_10) ------------------> op_b (tag_10)
187*523fa7a6SAndroid Build Coastguard Worker    constant_0_copy (tag_10) -------------> op_a (tag_11)
188*523fa7a6SAndroid Build Coastguard Worker
189*523fa7a6SAndroid Build Coastguard Worker    After the pass
190*523fa7a6SAndroid Build Coastguard Worker    constant_0 (tag_10) ------------------> op_b (tag_10)
191*523fa7a6SAndroid Build Coastguard Worker    constant_0_copy (tag_11) -------------> op_a (tag_11)
192*523fa7a6SAndroid Build Coastguard Worker
193*523fa7a6SAndroid Build Coastguard Worker    """
194*523fa7a6SAndroid Build Coastguard Worker    for node in tagged_exported_program.graph.nodes:
195*523fa7a6SAndroid Build Coastguard Worker        if node.op == "placeholder":
196*523fa7a6SAndroid Build Coastguard Worker            if node.name in copied_nodes:
197*523fa7a6SAndroid Build Coastguard Worker                users_tag = set()
198*523fa7a6SAndroid Build Coastguard Worker                for user in node.users:
199*523fa7a6SAndroid Build Coastguard Worker                    users_tag.add(user.meta.get("delegation_tag", None))
200*523fa7a6SAndroid Build Coastguard Worker                # Assign the tag to the copy constant node the same as their users.
201*523fa7a6SAndroid Build Coastguard Worker                if len(users_tag) == 1:
202*523fa7a6SAndroid Build Coastguard Worker                    node.meta["delegation_tag"] = users_tag.pop()
203*523fa7a6SAndroid Build Coastguard Worker
204*523fa7a6SAndroid Build Coastguard Worker
def _maybe_duplicate_constant_nodes(
    tagged_exported_program: ExportedProgram,
    tag: str,
) -> None:
    """
    Duplicate constant placeholders tagged with ``tag`` that are also used by
    nodes carrying a different tag, then re-tag the copies to match their users.

    If a constant node is shared by differently tagged nodes, like
    constant_0 ----> op_b (tag_10)
    |-------------> op_a (tag_11)

    then by default constant_0 is duplicated:
    constant_0 ------------------> op_b (tag_10)
    constant_0_copy -------------> op_a (tag_11)

    A constant whose ``meta["no_copy"]`` is truthy is never duplicated; sharing
    it across tags raises instead. This lets a backend decide, per constant,
    whether duplication is acceptable or should be an error.

    Args:
        tagged_exported_program: Program whose graph is modified in place.
        tag: Delegation tag whose constant placeholders are examined.

    Raises:
        RuntimeError: if a ``no_copy`` placeholder tagged with ``tag`` has a user
            with a different tag.
    """
    # Insertion-ordered dict used as an ordered set (graph order). Iterating a
    # set of strings follows hash order, which changes between interpreter runs
    # because of hash randomization, so duplication order would not be stable.
    candidate_nodes: Dict[str, None] = {}
    for node in tagged_exported_program.graph.nodes:
        if node.meta.get("delegation_tag", "") != tag or node.op != "placeholder":
            continue
        for user in node.users:
            users_tag = user.meta.get("delegation_tag", None)
            if users_tag == tag:
                continue
            # If the node is tagged with "no_copy", we stop duplicating it and throw an error
            if node.meta.get("no_copy", False):
                raise RuntimeError(
                    f"constant data node ({node}) is tagged with ({tag}) but has user ({user}) which has tag ({users_tag})"
                )
            candidate_nodes[node.name] = None

    copied_nodes: Set[str] = set()
    for candidate_node in candidate_nodes:
        # Both tagged exported program and the owning program need to go through the same duplication pass
        copied_nodes.update(
            duplicate_constant_node(tagged_exported_program, candidate_node)
        )
    candidate_node_with_copies = set(candidate_nodes).union(copied_nodes)
    _assign_new_tag(tagged_exported_program, candidate_node_with_copies)
242*523fa7a6SAndroid Build Coastguard Worker
243*523fa7a6SAndroid Build Coastguard Worker
def _get_item_from_executorch_call_delegate(node: torch.fx.Node) -> bool:
    """
    Check if the node is a getitem that reads one result of an
    executorch_call_delegate node. These getitem nodes exist only to unpack the
    delegate's results, because the inputs/outputs of delegates are flattened.

    Returns False (instead of raising AttributeError) when the getitem's source
    is not an FX node, e.g. a getitem applied to a literal tuple/list.
    """
    return (
        node.target == operator.getitem
        and len(node.args) == 2
        and isinstance(node.args[0], torch.fx.Node)
        and node.args[0].target == executorch_call_delegate
        and isinstance(node.args[1], int)
    )
255*523fa7a6SAndroid Build Coastguard Worker
256*523fa7a6SAndroid Build Coastguard Worker
def get_non_lowered_nodes(graph: torch.fx.Graph) -> List[torch.fx.Node]:
    """
    Collect the ``call_function`` nodes of ``graph`` that were not delegated.

    Nodes that call ``executorch_call_delegate``, and getitem nodes that only
    unpack a delegate's result, are skipped. Every other ``call_function`` node
    is returned in graph order.
    """
    non_lowered: List[torch.fx.Node] = []
    for node in graph.nodes:
        if node.op != "call_function":
            continue
        if node.target == executorch_call_delegate:
            continue
        if _get_item_from_executorch_call_delegate(node):
            continue
        non_lowered.append(node)
    return non_lowered
268*523fa7a6SAndroid Build Coastguard Worker
269*523fa7a6SAndroid Build Coastguard Worker
def get_delegates(graph: torch.fx.Graph) -> List[torch.fx.Node]:
    """
    Return, in graph order, the ``get_attr`` nodes of ``graph`` whose name
    starts with ``"lowered_module_"``, i.e. the references to lowered modules.
    """

    def _is_lowered_module_ref(candidate: torch.fx.Node) -> bool:
        return candidate.op == "get_attr" and candidate.name.startswith(
            "lowered_module_"
        )

    return list(filter(_is_lowered_module_ref, graph.nodes))
279*523fa7a6SAndroid Build Coastguard Worker
280*523fa7a6SAndroid Build Coastguard Worker
def print_delegated_graph(graph_module: torch.fx.GraphModule) -> None:
    """
    Write ``format_delegated_graph(graph_module)`` to standard output.
    """
    formatted = format_delegated_graph(graph_module)
    print(formatted)
286*523fa7a6SAndroid Build Coastguard Worker
287*523fa7a6SAndroid Build Coastguard Worker
def format_delegated_graph(graph_module: torch.fx.GraphModule) -> str:
    """
    Return the formatted graph string of ``graph_module``, expanding each
    lowered_module inline with its backend id and its original (lowered) graph.

    A ``get_attr`` node whose name starts with ``"lowered_module_"`` is followed
    by the backend id and the nodes of that module's ``original_module`` graph.
    Each nesting level is indented by two spaces, and every line, including the
    last, ends with a newline. Example output:

    graph():
      %arg0_1 : [num_users=2] = placeholder[target=arg0_1]
      %arg1_1 : [num_users=2] = placeholder[target=arg1_1]
      %arg2_1 : [num_users=2] = placeholder[target=arg2_1]
      %lowered_module_0 : [num_users=1] = get_attr[target=lowered_module_0]
        backend_id: BackendWithCompilerDemo
        lowered graph():
          %arg0_1 : [num_users=1] = placeholder[target=arg0_1]
          %arg1_1 : [num_users=1] = placeholder[target=arg1_1]
          %arg2_1 : [num_users=1] = placeholder[target=arg2_1]
          %aten_mm_default : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten.mm.default](args = (%arg0_1, %arg1_1), kwargs = {})
          %aten_add_tensor : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten.add.Tensor](args = (%aten_mm_default, %arg2_1), kwargs = {})
          return [aten_add_tensor]
      %executorch_call_delegate : [num_users=1] = call_function[target=torch.ops.higher_order.executorch_call_delegate](args = (%lowered_module_0, %arg0_1, %arg1_1, %arg2_1), kwargs = {})
      %getitem : [num_users=1] = call_function[target=operator.getitem](args = (%executorch_call_delegate, 0), kwargs = {})
      %aten_sub_tensor : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten.sub.Tensor](args = (%getitem, %arg0_1), kwargs = {})
      %lowered_module_1 : [num_users=1] = get_attr[target=lowered_module_1]
        backend_id: BackendWithCompilerDemo
        lowered graph():
          %aten_sub_tensor : [num_users=1] = placeholder[target=aten_sub_tensor]
          %arg1_1 : [num_users=1] = placeholder[target=arg1_1]
          %arg2_1 : [num_users=1] = placeholder[target=arg2_1]
          %aten_mm_default_1 : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten.mm.default](args = (%aten_sub_tensor, %arg1_1), kwargs = {})
          %aten_add_tensor_1 : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten.add.Tensor](args = (%aten_mm_default_1, %arg2_1), kwargs = {})
          return [aten_add_tensor_1]
      %executorch_call_delegate_1 : [num_users=1] = call_function[target=torch.ops.higher_order.executorch_call_delegate](args = (%lowered_module_1, %aten_sub_tensor, %arg1_1, %arg2_1), kwargs = {})
      %getitem_1 : [num_users=1] = call_function[target=operator.getitem](args = (%executorch_call_delegate_1, 0), kwargs = {})
      return [getitem_1]
    """
    indent = "  "
    # Collect lines and join once instead of repeatedly concatenating strings.
    lines = ["graph():"]
    for node in graph_module.graph.nodes:
        lines.append(f"{indent}{node.format_node()}")
        if node.op == "get_attr" and node.name.startswith("lowered_module_"):
            lowered_module = getattr(graph_module, node.name)
            lines.append(f"{indent * 2}backend_id: {lowered_module.backend_id}")
            lines.append(f"{indent * 2}lowered graph():")
            lines.extend(
                f"{indent * 3}{node_in_lowered_module.format_node()}"
                for node_in_lowered_module in lowered_module.original_module.graph.nodes
            )
    return "\n".join(lines) + "\n"
338*523fa7a6SAndroid Build Coastguard Worker
339*523fa7a6SAndroid Build Coastguard Worker
340*523fa7a6SAndroid Build Coastguard Workerdef tag_constant_data(edge_program: ExportedProgram) -> None:
341*523fa7a6SAndroid Build Coastguard Worker    """
342*523fa7a6SAndroid Build Coastguard Worker    Util function for partitioners. This function tags the const/param/buffers nodes
343*523fa7a6SAndroid Build Coastguard Worker    whose users all belong within the same partition. This should be called after tagging all other nodes.
344*523fa7a6SAndroid Build Coastguard Worker    Any const/param/buffer which is used as input to a subgraph, will be tagged with the same tag as that
345*523fa7a6SAndroid Build Coastguard Worker    subgraph. Throw error when const/param/buffers is used across different partitions. That is the
346*523fa7a6SAndroid Build Coastguard Worker    underlying data will be owned by multiple delegates.
347*523fa7a6SAndroid Build Coastguard Worker    """
348*523fa7a6SAndroid Build Coastguard Worker    mutated_buffer = set()
349*523fa7a6SAndroid Build Coastguard Worker    for node in edge_program.graph.nodes:
350*523fa7a6SAndroid Build Coastguard Worker        if node.op == "placeholder" and (
351*523fa7a6SAndroid Build Coastguard Worker            is_param(edge_program, node)
352*523fa7a6SAndroid Build Coastguard Worker            or is_buffer(edge_program, node)
353*523fa7a6SAndroid Build Coastguard Worker            or is_lifted_tensor_constant(edge_program, node)
354*523fa7a6SAndroid Build Coastguard Worker        ):
355*523fa7a6SAndroid Build Coastguard Worker            for node_user in node.users:
356*523fa7a6SAndroid Build Coastguard Worker                if node_user.name in edge_program.graph_signature.buffers_to_mutate:
357*523fa7a6SAndroid Build Coastguard Worker                    logging.info(
358*523fa7a6SAndroid Build Coastguard Worker                        "The buffer node is a mutated buffer node, which is not constant."
359*523fa7a6SAndroid Build Coastguard Worker                    )
360*523fa7a6SAndroid Build Coastguard Worker                    mutated_buffer.add(node)
361*523fa7a6SAndroid Build Coastguard Worker
362*523fa7a6SAndroid Build Coastguard Worker    for node in edge_program.graph.nodes:
363*523fa7a6SAndroid Build Coastguard Worker        # go through const/param/buffer nodes, if all users of const/param/buffer nodes are partitioned then partition
364*523fa7a6SAndroid Build Coastguard Worker        if node.op == "placeholder" and (
365*523fa7a6SAndroid Build Coastguard Worker            is_param(edge_program, node)
366*523fa7a6SAndroid Build Coastguard Worker            or is_buffer(edge_program, node)
367*523fa7a6SAndroid Build Coastguard Worker            or is_lifted_tensor_constant(edge_program, node)
368*523fa7a6SAndroid Build Coastguard Worker        ):
369*523fa7a6SAndroid Build Coastguard Worker            if node not in mutated_buffer:
370*523fa7a6SAndroid Build Coastguard Worker                user_tags = set()
371*523fa7a6SAndroid Build Coastguard Worker                for user in node.users:
372*523fa7a6SAndroid Build Coastguard Worker                    user_tag = user.meta.get("delegation_tag", None)
373*523fa7a6SAndroid Build Coastguard Worker                    if user_tag is not None:
374*523fa7a6SAndroid Build Coastguard Worker                        user_tags.add(user_tag)
375*523fa7a6SAndroid Build Coastguard Worker                if len(user_tags) > 1:
376*523fa7a6SAndroid Build Coastguard Worker                    logging.info(
377*523fa7a6SAndroid Build Coastguard Worker                        f"The data node is used across multiple partitions, including {user_tags}. "
378*523fa7a6SAndroid Build Coastguard Worker                        "If the data is too large and it's not preferred to copy, please tag the "
379*523fa7a6SAndroid Build Coastguard Worker                        "constant node like node.['no_copy'] = True and they won't be copied."
380*523fa7a6SAndroid Build Coastguard Worker                    )
381*523fa7a6SAndroid Build Coastguard Worker                # tag the data node with the same tag as the last user
382*523fa7a6SAndroid Build Coastguard Worker                if len(user_tags) > 0:
383*523fa7a6SAndroid Build Coastguard Worker                    node.meta["delegation_tag"] = user_tags.pop()
384*523fa7a6SAndroid Build Coastguard Worker
385*523fa7a6SAndroid Build Coastguard Worker
386*523fa7a6SAndroid Build Coastguard Workerdef tag_mutated_buffer(edge_program: ExportedProgram) -> None:
387*523fa7a6SAndroid Build Coastguard Worker    """
388*523fa7a6SAndroid Build Coastguard Worker    Util function for partitioners. This function tags the mutated buffer nodes
389*523fa7a6SAndroid Build Coastguard Worker    whose users all belong within the same partition. This should be called after tagging all other nodes.
390*523fa7a6SAndroid Build Coastguard Worker    Any buffer which is used as input to a subgraph, will be tagged with the same tag as that
391*523fa7a6SAndroid Build Coastguard Worker    subgraph. Throw error when buffers is used across different partitions. That is the
392*523fa7a6SAndroid Build Coastguard Worker    underlying data will be owned by multiple delegates.
393*523fa7a6SAndroid Build Coastguard Worker    """
394*523fa7a6SAndroid Build Coastguard Worker    for node in edge_program.graph.nodes:
395*523fa7a6SAndroid Build Coastguard Worker        # Determine whether this node is a mutated buffer
396*523fa7a6SAndroid Build Coastguard Worker        is_mutated_buffer_node = False
397*523fa7a6SAndroid Build Coastguard Worker        if node.op == "placeholder" and is_buffer(edge_program, node):
398*523fa7a6SAndroid Build Coastguard Worker            for node_user in node.users:
399*523fa7a6SAndroid Build Coastguard Worker                if node_user.name in edge_program.graph_signature.buffers_to_mutate:
400*523fa7a6SAndroid Build Coastguard Worker                    is_mutated_buffer_node = True
401*523fa7a6SAndroid Build Coastguard Worker                    break
402*523fa7a6SAndroid Build Coastguard Worker        # This node is mutated buffer, tag it
403*523fa7a6SAndroid Build Coastguard Worker        if is_mutated_buffer_node:
404*523fa7a6SAndroid Build Coastguard Worker            user_tags = set()
405*523fa7a6SAndroid Build Coastguard Worker            for user in node.users:
406*523fa7a6SAndroid Build Coastguard Worker                user_tag = user.meta.get("delegation_tag", None)
407*523fa7a6SAndroid Build Coastguard Worker                if user_tag is not None:
408*523fa7a6SAndroid Build Coastguard Worker                    user_tags.add(user_tag)
409*523fa7a6SAndroid Build Coastguard Worker            if len(user_tags) > 1:
410*523fa7a6SAndroid Build Coastguard Worker                logging.info(
411*523fa7a6SAndroid Build Coastguard Worker                    f"The data node is used across multiple partitions, including {user_tags}. "
412*523fa7a6SAndroid Build Coastguard Worker                    "If the data is too large and it's not preferred to copy, please tag the "
413*523fa7a6SAndroid Build Coastguard Worker                    "constant node like node.['no_copy'] = True and they won't be copied."
414*523fa7a6SAndroid Build Coastguard Worker                )
415*523fa7a6SAndroid Build Coastguard Worker            # tag the data node with the same tag as the last user
416*523fa7a6SAndroid Build Coastguard Worker            if len(user_tags) > 0:
417*523fa7a6SAndroid Build Coastguard Worker                node.meta["delegation_tag"] = user_tags.pop()
418*523fa7a6SAndroid Build Coastguard Worker
419*523fa7a6SAndroid Build Coastguard Worker
# TODO - style: use templated types
class DelegateMappingBuilder:
    """
    Profiling helper class for building Delegate Mappings.
    Delegate Mappings are mappings from delegate debug identifiers to node
    debug handles. Specifically this is used to log within backend delegates.

    Args:
        generated_identifiers (bool, optional): Whether identifier keys are
            generated automatically. Defaults to False.
    """

    def __init__(self, generated_identifiers: bool = False):
        self._generated_identifiers = generated_identifiers

        # Note that the internal struct has a Set value, while the getter
        # function returns the values as a sorted tuple
        self._debug_handle_map: Union[Dict[int, Set[int]], Dict[str, Set[int]]] = (
            defaultdict(set)
        )
        # Next identifier to hand out when identifiers are generated; it is only
        # advanced by a successful insert.
        self._next_index: int = 0

    def get_delegate_mapping(
        self,
    ) -> Union[Dict[int, Tuple[int]], Dict[str, Tuple[int]]]:
        """
        Returns:
           Union[Dict[int, Tuple[int]], Dict[str, Tuple[int]]]:
                A map of delegate debug identifier to a list of debug handles
                The keys (identifier) are either integers or strings
                The values are a sorted tuple of integer debug handles
        """
        # pyre-ignore Warning between Union[Dict[K, V], Dict[K2, V]] vs Dict[Union[K, K2], V]
        return {k: tuple(sorted(v)) for k, v in self._debug_handle_map.items()}

    def insert_delegate_mapping_entry(
        self,
        nodes: Optional[Union[Node, List[Node]]] = None,
        handles: Optional[Union[int, List[Optional[int]]]] = None,
        identifier: Optional[Union[int, str]] = None,
    ) -> Union[int, str]:
        """
        Add a new delegate mapping entry

        If self._generated_identifiers = False:
            - A new identifier must be provided, else an exception is thrown

        If self._generated_identifiers = True:
            - New identifiers are generated incrementally, 0 indexed
            - Identifiers cannot be manually provided, else an exception is thrown
            - An identifier is only consumed when the entry is actually inserted

        Args:
            nodes (Union[Node, List[Node]]): A (list or tuple of) Node(s)
            handles (Union[int, List[Optional[int]]]): A (list or tuple of) debug handle(s)
            identifier (Optional[Union[int, str]]):
                Debug identifier corresponding to the Node(s)

        Note: Exactly one of nodes and handles must be provided
        Note: If a debug handle is missing or None, it is skipped

        Returns:
            Union[int, str]:
                Delegate debug identifier inserted

        Raises:
            Exception: if an identifier is given while identifiers are generated,
                the identifier is already present, not exactly one of nodes/handles
                is given, no identifier is given while identifiers are not
                generated, or no non-None debug handle remains.
        """

        # Check for manual addition of identifier (with generated identifiers enabled)
        if self._generated_identifiers and identifier is not None:
            raise Exception(
                f"Builders using generated identifiers can't manually add identifiers: {identifier}. Failed to add or update entry"
            )

        if identifier is not None and identifier in self._debug_handle_map:
            raise Exception(
                "This delegate debug identifier was already inserted. Duplicate delegate debug identifiers are not allowed."
            )

        # Check for exactly one of nodes and handles being populated
        if (nodes is None) == (handles is None):
            raise Exception(
                "Only one of nodes or handles must be provided. Either both were provided or neither were provided. Failed to add or update entry."
            )

        if identifier is None and not self._generated_identifiers:
            raise Exception("No identifier provided. Failed to add or update entry.")

        # Collect debug handles
        if nodes is not None:
            node_list = nodes if isinstance(nodes, (list, tuple)) else [nodes]
            new_debug_handles = [node.meta.get("debug_handle") for node in node_list]
        else:
            new_debug_handles = (
                handles if isinstance(handles, (list, tuple)) else [handles]
            )

        # Filter for empty debug handles
        filtered_debug_handles = {
            handle for handle in new_debug_handles if handle is not None
        }
        if not filtered_debug_handles:
            raise Exception("No valid debug handles found. Failed to add entry.")

        # Resolve a generated identifier only after validation succeeded, so a
        # rejected entry does not leave a gap in the generated sequence.
        if identifier is None:
            identifier = self._next_index
            self._next_index += 1

        # pyre-ignore Warning from Union[int, str] keys
        self._debug_handle_map[identifier] = filtered_debug_handles
        return identifier
533*523fa7a6SAndroid Build Coastguard Worker
534*523fa7a6SAndroid Build Coastguard Worker
class WhyNoPartition:
    """
    Small logging helper that partitioners call to record why a node was not
    lowered.

    Each call stores the node and the reason on the instance and then hands the
    instance itself to ``logger.debug``; the logged text is produced by
    ``__str__``.

    Example usage:

        # In your backend partitioner file(s)
        why = WhyNoPartition(logger=your_backend_logger)

        # hypothetical function that checks if a node can be lowered
        if not can_be_lowered(node):
            why(node, "This node was not lowered because ...")
    """

    def __init__(self, logger: logging.Logger):
        self.logger = logger
        # Most recently reported node and reason; overwritten on every call.
        self.reason: str = ""
        self.node: Optional[torch.fx.Node] = None

    def __call__(self, node: torch.fx.Node, reason: str) -> None:
        self.node, self.reason = node, reason
        # The object (not a pre-built string) is logged, so __str__ only runs
        # when the logging machinery actually formats the record.
        self.logger.debug(self)

    def __str__(self) -> str:
        node, reason = self.node, self.reason
        return f"WhyNoPartition: Node {node} was not partitioned because {reason}."
561