1*523fa7a6SAndroid Build Coastguard Worker# Copyright (c) Meta Platforms, Inc. and affiliates. 2*523fa7a6SAndroid Build Coastguard Worker# All rights reserved. 3*523fa7a6SAndroid Build Coastguard Worker# 4*523fa7a6SAndroid Build Coastguard Worker# This source code is licensed under the BSD-style license found in the 5*523fa7a6SAndroid Build Coastguard Worker# LICENSE file in the root directory of this source tree. 6*523fa7a6SAndroid Build Coastguard Worker 7*523fa7a6SAndroid Build Coastguard Workerimport logging 8*523fa7a6SAndroid Build Coastguard Workerimport operator 9*523fa7a6SAndroid Build Coastguard Workerfrom collections import defaultdict 10*523fa7a6SAndroid Build Coastguard Workerfrom functools import lru_cache 11*523fa7a6SAndroid Build Coastguard Workerfrom typing import Dict, Iterable, List, Optional, Set, Tuple, Union 12*523fa7a6SAndroid Build Coastguard Worker 13*523fa7a6SAndroid Build Coastguard Workerimport torch 14*523fa7a6SAndroid Build Coastguard Workerfrom executorch.exir.backend.backend_details import ExportedProgram 15*523fa7a6SAndroid Build Coastguard Workerfrom executorch.exir.backend.canonical_partitioners.duplicate_constant_node_pass import ( 16*523fa7a6SAndroid Build Coastguard Worker duplicate_constant_node, 17*523fa7a6SAndroid Build Coastguard Worker) 18*523fa7a6SAndroid Build Coastguard Workerfrom executorch.exir.common import setting_python_recursive_limit 19*523fa7a6SAndroid Build Coastguard Workerfrom executorch.exir.delegate import executorch_call_delegate 20*523fa7a6SAndroid Build Coastguard Workerfrom executorch.exir.dialects._ops import ops as exir_ops 21*523fa7a6SAndroid Build Coastguard Worker 22*523fa7a6SAndroid Build Coastguard Workerfrom executorch.exir.lowered_backend_module import create_submodule_from_nodes 23*523fa7a6SAndroid Build Coastguard Workerfrom torch._export.utils import is_buffer, is_lifted_tensor_constant, is_param 24*523fa7a6SAndroid Build Coastguard Workerfrom torch.fx.node import Node 25*523fa7a6SAndroid Build Coastguard Workerfrom torch.fx.passes.utils.source_matcher_utils import SourcePartition 26*523fa7a6SAndroid Build Coastguard Worker 27*523fa7a6SAndroid Build Coastguard WorkerT_QuantPerTensor = exir_ops.edge.quantized_decomposed.quantize_per_tensor.default 28*523fa7a6SAndroid Build Coastguard WorkerT_DQuantPerTensor = exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default 29*523fa7a6SAndroid Build Coastguard Worker 30*523fa7a6SAndroid Build Coastguard Worker 31*523fa7a6SAndroid Build Coastguard Worker# NB: Set this to None to handle validation from MobileBert 32*523fa7a6SAndroid Build Coastguard Worker@lru_cache(maxsize=None) 33*523fa7a6SAndroid Build Coastguard Workerdef is_same_node( 34*523fa7a6SAndroid Build Coastguard Worker node_left: Iterable[torch.fx.Node], 35*523fa7a6SAndroid Build Coastguard Worker node_right: Iterable[torch.fx.Node], 36*523fa7a6SAndroid Build Coastguard Worker) -> bool: 37*523fa7a6SAndroid Build Coastguard Worker # two nodes are the same if they have the same target and op 38*523fa7a6SAndroid Build Coastguard Worker # same for their args 39*523fa7a6SAndroid Build Coastguard Worker if isinstance(node_left, torch.fx.Node) and isinstance(node_right, torch.fx.Node): 40*523fa7a6SAndroid Build Coastguard Worker if not ( 41*523fa7a6SAndroid Build Coastguard Worker (node_left.target == node_right.target) 42*523fa7a6SAndroid Build Coastguard Worker and (node_left.op == node_right.op) 43*523fa7a6SAndroid Build Coastguard Worker and (len(node_left.all_input_nodes) == len(node_right.all_input_nodes)) 44*523fa7a6SAndroid Build Coastguard Worker and all( 45*523fa7a6SAndroid Build Coastguard Worker is_same_node(arg_left, arg_right) 46*523fa7a6SAndroid Build Coastguard Worker for arg_left, arg_right in zip( 47*523fa7a6SAndroid Build Coastguard Worker node_left.all_input_nodes, node_right.all_input_nodes 48*523fa7a6SAndroid Build Coastguard Worker ) 49*523fa7a6SAndroid Build Coastguard Worker ) 50*523fa7a6SAndroid Build Coastguard Worker ): 51*523fa7a6SAndroid Build Coastguard Worker return False 52*523fa7a6SAndroid Build Coastguard Worker else: 53*523fa7a6SAndroid Build Coastguard Worker if len(list(node_left)) != len(list(node_right)): 54*523fa7a6SAndroid Build Coastguard Worker return False 55*523fa7a6SAndroid Build Coastguard Worker for n_left, n_right in zip(node_left, node_right): 56*523fa7a6SAndroid Build Coastguard Worker if not is_same_node(n_left, n_right): 57*523fa7a6SAndroid Build Coastguard Worker return False 58*523fa7a6SAndroid Build Coastguard Worker return True 59*523fa7a6SAndroid Build Coastguard Worker 60*523fa7a6SAndroid Build Coastguard Worker 61*523fa7a6SAndroid Build Coastguard Workerdef is_identical_graph( 62*523fa7a6SAndroid Build Coastguard Worker graph_left: torch.fx.GraphModule, graph_right: torch.fx.GraphModule 63*523fa7a6SAndroid Build Coastguard Worker) -> bool: 64*523fa7a6SAndroid Build Coastguard Worker # two graph are the same if they have the same nodes and op. The order of nodes also 65*523fa7a6SAndroid Build Coastguard Worker # matters in this function is more strict. Two graph are not considered as the same 66*523fa7a6SAndroid Build Coastguard Worker # if the topological order of the nodes is the same in this function but the order of nodes 67*523fa7a6SAndroid Build Coastguard Worker # is not the same. 68*523fa7a6SAndroid Build Coastguard Worker if len(list(graph_left.graph.nodes)) != len(list(graph_right.graph.nodes)): 69*523fa7a6SAndroid Build Coastguard Worker return False 70*523fa7a6SAndroid Build Coastguard Worker with setting_python_recursive_limit(30000): 71*523fa7a6SAndroid Build Coastguard Worker for node_left, node_right in zip( 72*523fa7a6SAndroid Build Coastguard Worker graph_left.graph.nodes, graph_right.graph.nodes 73*523fa7a6SAndroid Build Coastguard Worker ): 74*523fa7a6SAndroid Build Coastguard Worker if not (is_same_node(node_left, node_right)): 75*523fa7a6SAndroid Build Coastguard Worker return False 76*523fa7a6SAndroid Build Coastguard Worker return True 77*523fa7a6SAndroid Build Coastguard Worker 78*523fa7a6SAndroid Build Coastguard Worker 79*523fa7a6SAndroid Build Coastguard Workerdef remove_first_quant_and_last_dequant( 80*523fa7a6SAndroid Build Coastguard Worker graph_module: torch.fx.GraphModule, 81*523fa7a6SAndroid Build Coastguard Worker) -> None: 82*523fa7a6SAndroid Build Coastguard Worker for node in graph_module.graph.nodes: 83*523fa7a6SAndroid Build Coastguard Worker if node.target == T_QuantPerTensor: 84*523fa7a6SAndroid Build Coastguard Worker if node.args[0].op == "placeholder": 85*523fa7a6SAndroid Build Coastguard Worker node_users = list(node.users.keys()) 86*523fa7a6SAndroid Build Coastguard Worker for dequant_node in node_users: 87*523fa7a6SAndroid Build Coastguard Worker # point the dequant arg to the placeholder 88*523fa7a6SAndroid Build Coastguard Worker dequant_node.args = (node.args[0],) + dequant_node.args[1:] 89*523fa7a6SAndroid Build Coastguard Worker elif node.target == T_DQuantPerTensor: 90*523fa7a6SAndroid Build Coastguard Worker node_users = list(node.users.keys()) 91*523fa7a6SAndroid Build Coastguard Worker if node_users[0].op == "output": 92*523fa7a6SAndroid Build Coastguard Worker # point the output arg to the quant node 93*523fa7a6SAndroid Build Coastguard Worker output_node = node_users[0] 94*523fa7a6SAndroid Build Coastguard Worker output_node.args = ([node.args[0]],) 95*523fa7a6SAndroid Build Coastguard Worker # Remove the quant/dequant nodes as they don't have users 96*523fa7a6SAndroid Build Coastguard Worker graph_module.graph.eliminate_dead_code() 97*523fa7a6SAndroid Build Coastguard Worker graph_module.recompile() 98*523fa7a6SAndroid Build Coastguard Worker 99*523fa7a6SAndroid Build Coastguard Worker 100*523fa7a6SAndroid Build Coastguard Worker# TODO - use edge ops 101*523fa7a6SAndroid Build Coastguard Workerdef replace_quantized_partition_with_op( 102*523fa7a6SAndroid Build Coastguard Worker graph_module: torch.fx.GraphModule, 103*523fa7a6SAndroid Build Coastguard Worker partition: SourcePartition, 104*523fa7a6SAndroid Build Coastguard Worker replacement_op: torch._ops.OpOverloadPacket, 105*523fa7a6SAndroid Build Coastguard Worker) -> Tuple[torch.fx.Node, List[torch.fx.Node], List[torch.fx.Node]]: 106*523fa7a6SAndroid Build Coastguard Worker """ 107*523fa7a6SAndroid Build Coastguard Worker Replaces partition with the op specified by replacement_op. It's also expected that 108*523fa7a6SAndroid Build Coastguard Worker the nodes contained in partition are sourced from a quantized module as this function 109*523fa7a6SAndroid Build Coastguard Worker searches for the quantization pattern to consume along with the nodes in the partition, 110*523fa7a6SAndroid Build Coastguard Worker to be then replaced by replacement_op. 111*523fa7a6SAndroid Build Coastguard Worker 112*523fa7a6SAndroid Build Coastguard Worker Args: 113*523fa7a6SAndroid Build Coastguard Worker graph_module: The graph module from which this partition was sourced. 114*523fa7a6SAndroid Build Coastguard Worker partition: Partition to be replaced. 115*523fa7a6SAndroid Build Coastguard Worker replacement_op: The op to replace paritition with. 116*523fa7a6SAndroid Build Coastguard Worker Returns: 117*523fa7a6SAndroid Build Coastguard Worker Tuple: First element in the tuple is the new replaced module. The second and third 118*523fa7a6SAndroid Build Coastguard Worker node lists in the returned tuple consist of the dq and q nodes that were consumed 119*523fa7a6SAndroid Build Coastguard Worker along with this partition to be replaced by the replacement_op. 120*523fa7a6SAndroid Build Coastguard Worker """ 121*523fa7a6SAndroid Build Coastguard Worker 122*523fa7a6SAndroid Build Coastguard Worker dequant_nodes = [] 123*523fa7a6SAndroid Build Coastguard Worker quant_nodes = [] 124*523fa7a6SAndroid Build Coastguard Worker input_nodes = [] 125*523fa7a6SAndroid Build Coastguard Worker output_nodes = [] 126*523fa7a6SAndroid Build Coastguard Worker 127*523fa7a6SAndroid Build Coastguard Worker partition_nodes = [node for node in partition.nodes if node not in partition.params] 128*523fa7a6SAndroid Build Coastguard Worker 129*523fa7a6SAndroid Build Coastguard Worker # We recreate our input nodes and output nodes list instead of using partition.input_nodes 130*523fa7a6SAndroid Build Coastguard Worker # and partition.output_nodes as the ordering of the nodes in those lists is not deterministic, 131*523fa7a6SAndroid Build Coastguard Worker # whereas for the quant fusion pass we expect deterministic ordering. 132*523fa7a6SAndroid Build Coastguard Worker for node in partition.nodes: 133*523fa7a6SAndroid Build Coastguard Worker for arg in node.args: 134*523fa7a6SAndroid Build Coastguard Worker if isinstance(arg, torch.fx.Node) and (arg not in partition.nodes): 135*523fa7a6SAndroid Build Coastguard Worker input_nodes.append(arg) 136*523fa7a6SAndroid Build Coastguard Worker 137*523fa7a6SAndroid Build Coastguard Worker for user in node.users.keys(): 138*523fa7a6SAndroid Build Coastguard Worker if user not in partition.nodes: 139*523fa7a6SAndroid Build Coastguard Worker output_nodes.append(node) 140*523fa7a6SAndroid Build Coastguard Worker 141*523fa7a6SAndroid Build Coastguard Worker # Try to find all the dq nodes that are feeding into this module. 142*523fa7a6SAndroid Build Coastguard Worker for node in input_nodes: 143*523fa7a6SAndroid Build Coastguard Worker if node.target == T_DQuantPerTensor: 144*523fa7a6SAndroid Build Coastguard Worker dequant_nodes += [node] 145*523fa7a6SAndroid Build Coastguard Worker 146*523fa7a6SAndroid Build Coastguard Worker # Try to find all the q nodes that this module is feeding out into. 147*523fa7a6SAndroid Build Coastguard Worker for node in output_nodes: 148*523fa7a6SAndroid Build Coastguard Worker for user in node.users.keys(): 149*523fa7a6SAndroid Build Coastguard Worker if user.target == T_QuantPerTensor: 150*523fa7a6SAndroid Build Coastguard Worker quant_nodes += [user] 151*523fa7a6SAndroid Build Coastguard Worker 152*523fa7a6SAndroid Build Coastguard Worker assert len(dequant_nodes) >= 1, "Dequant nodes missing in node list to be replaced." 153*523fa7a6SAndroid Build Coastguard Worker assert len(quant_nodes) >= 1, "Quant nodes missing in node list to be replaced." 154*523fa7a6SAndroid Build Coastguard Worker 155*523fa7a6SAndroid Build Coastguard Worker # After this, node list will essentially contain all the nodes in the 156*523fa7a6SAndroid Build Coastguard Worker # dq->op->q pattern that we will want to replace with a custom backend op. 157*523fa7a6SAndroid Build Coastguard Worker node_list = dequant_nodes + partition_nodes + quant_nodes 158*523fa7a6SAndroid Build Coastguard Worker 159*523fa7a6SAndroid Build Coastguard Worker submodule, call_module_node = create_submodule_from_nodes( 160*523fa7a6SAndroid Build Coastguard Worker graph_module, node_list, "to_be_replaced", skip_legalize_graph=True 161*523fa7a6SAndroid Build Coastguard Worker ) 162*523fa7a6SAndroid Build Coastguard Worker 163*523fa7a6SAndroid Build Coastguard Worker # Update the replaced op so that we have all the latest args and kwargs. 164*523fa7a6SAndroid Build Coastguard Worker with graph_module.graph.inserting_before(call_module_node): 165*523fa7a6SAndroid Build Coastguard Worker replaced_op = graph_module.graph.call_function( 166*523fa7a6SAndroid Build Coastguard Worker replacement_op, 167*523fa7a6SAndroid Build Coastguard Worker call_module_node.args, 168*523fa7a6SAndroid Build Coastguard Worker kwargs=call_module_node.kwargs, 169*523fa7a6SAndroid Build Coastguard Worker ) 170*523fa7a6SAndroid Build Coastguard Worker call_module_node.replace_all_uses_with(replaced_op) 171*523fa7a6SAndroid Build Coastguard Worker graph_module.graph.erase_node(call_module_node) 172*523fa7a6SAndroid Build Coastguard Worker replaced_op.meta = call_module_node.meta 173*523fa7a6SAndroid Build Coastguard Worker graph_module.recompile() 174*523fa7a6SAndroid Build Coastguard Worker 175*523fa7a6SAndroid Build Coastguard Worker return (replaced_op, dequant_nodes, quant_nodes) 176*523fa7a6SAndroid Build Coastguard Worker 177*523fa7a6SAndroid Build Coastguard Worker 178*523fa7a6SAndroid Build Coastguard Workerdef _assign_new_tag( 179*523fa7a6SAndroid Build Coastguard Worker tagged_exported_program: ExportedProgram, 180*523fa7a6SAndroid Build Coastguard Worker copied_nodes: Set[str], 181*523fa7a6SAndroid Build Coastguard Worker): 182*523fa7a6SAndroid Build Coastguard Worker """ 183*523fa7a6SAndroid Build Coastguard Worker Assign new tag to the copied nodes. 184*523fa7a6SAndroid Build Coastguard Worker 185*523fa7a6SAndroid Build Coastguard Worker Before the pass 186*523fa7a6SAndroid Build Coastguard Worker constant_0 (tag_10) ------------------> op_b (tag_10) 187*523fa7a6SAndroid Build Coastguard Worker constant_0_copy (tag_10) -------------> op_a (tag_11) 188*523fa7a6SAndroid Build Coastguard Worker 189*523fa7a6SAndroid Build Coastguard Worker After the pass 190*523fa7a6SAndroid Build Coastguard Worker constant_0 (tag_10) ------------------> op_b (tag_10) 191*523fa7a6SAndroid Build Coastguard Worker constant_0_copy (tag_11) -------------> op_a (tag_11) 192*523fa7a6SAndroid Build Coastguard Worker 193*523fa7a6SAndroid Build Coastguard Worker """ 194*523fa7a6SAndroid Build Coastguard Worker for node in tagged_exported_program.graph.nodes: 195*523fa7a6SAndroid Build Coastguard Worker if node.op == "placeholder": 196*523fa7a6SAndroid Build Coastguard Worker if node.name in copied_nodes: 197*523fa7a6SAndroid Build Coastguard Worker users_tag = set() 198*523fa7a6SAndroid Build Coastguard Worker for user in node.users: 199*523fa7a6SAndroid Build Coastguard Worker users_tag.add(user.meta.get("delegation_tag", None)) 200*523fa7a6SAndroid Build Coastguard Worker # Assign the tag to the copy constant node the same as their users. 201*523fa7a6SAndroid Build Coastguard Worker if len(users_tag) == 1: 202*523fa7a6SAndroid Build Coastguard Worker node.meta["delegation_tag"] = users_tag.pop() 203*523fa7a6SAndroid Build Coastguard Worker 204*523fa7a6SAndroid Build Coastguard Worker 205*523fa7a6SAndroid Build Coastguard Workerdef _maybe_duplicate_constant_nodes( 206*523fa7a6SAndroid Build Coastguard Worker tagged_exported_program: ExportedProgram, 207*523fa7a6SAndroid Build Coastguard Worker tag: str, 208*523fa7a6SAndroid Build Coastguard Worker) -> None: 209*523fa7a6SAndroid Build Coastguard Worker """ 210*523fa7a6SAndroid Build Coastguard Worker If the constants node is shared by different tagged nodes, like 211*523fa7a6SAndroid Build Coastguard Worker constant_0 ----> op_b (tag_10) 212*523fa7a6SAndroid Build Coastguard Worker |-------------> op_a (tag_11) 213*523fa7a6SAndroid Build Coastguard Worker 214*523fa7a6SAndroid Build Coastguard Worker we make default as constant_0 is duplicated to constant_0_1, constant_0_2, unless the node is tagged with "no_copy" 215*523fa7a6SAndroid Build Coastguard Worker constant_0 ------------------> op_b (tag_10) 216*523fa7a6SAndroid Build Coastguard Worker constant_0_copy -------------> op_a (tag_11) 217*523fa7a6SAndroid Build Coastguard Worker 218*523fa7a6SAndroid Build Coastguard Worker backend can estimate how much they want to duplicate the constant node, either error out or default to duplicate 219*523fa7a6SAndroid Build Coastguard Worker """ 220*523fa7a6SAndroid Build Coastguard Worker candidate_nodes = set() 221*523fa7a6SAndroid Build Coastguard Worker for node in tagged_exported_program.graph.nodes: 222*523fa7a6SAndroid Build Coastguard Worker if node.meta.get("delegation_tag", "") == tag: 223*523fa7a6SAndroid Build Coastguard Worker if node.op == "placeholder": 224*523fa7a6SAndroid Build Coastguard Worker for user in node.users: 225*523fa7a6SAndroid Build Coastguard Worker users_tag = user.meta.get("delegation_tag", None) 226*523fa7a6SAndroid Build Coastguard Worker if users_tag != tag: 227*523fa7a6SAndroid Build Coastguard Worker # If the node is tagged with "no_copy", we stop duplicating it and throw an error 228*523fa7a6SAndroid Build Coastguard Worker if node.meta.get("no_copy", False): 229*523fa7a6SAndroid Build Coastguard Worker raise RuntimeError( 230*523fa7a6SAndroid Build Coastguard Worker f"constant data node ({node}) is tagged with ({tag}) but has user ({user}) which has tag ({users_tag})" 231*523fa7a6SAndroid Build Coastguard Worker ) 232*523fa7a6SAndroid Build Coastguard Worker else: 233*523fa7a6SAndroid Build Coastguard Worker candidate_nodes.add(node.name) 234*523fa7a6SAndroid Build Coastguard Worker copied_nodes = set() 235*523fa7a6SAndroid Build Coastguard Worker for candidate_node in candidate_nodes: 236*523fa7a6SAndroid Build Coastguard Worker # Both tagged exported program and the owning program need to go through the same duplication pass 237*523fa7a6SAndroid Build Coastguard Worker copied_nodes = copied_nodes.union( 238*523fa7a6SAndroid Build Coastguard Worker duplicate_constant_node(tagged_exported_program, candidate_node) 239*523fa7a6SAndroid Build Coastguard Worker ) 240*523fa7a6SAndroid Build Coastguard Worker candidate_node_with_copies = candidate_nodes.union(copied_nodes) 241*523fa7a6SAndroid Build Coastguard Worker _assign_new_tag(tagged_exported_program, candidate_node_with_copies) 242*523fa7a6SAndroid Build Coastguard Worker 243*523fa7a6SAndroid Build Coastguard Worker 244*523fa7a6SAndroid Build Coastguard Workerdef _get_item_from_executorch_call_delegate(node: torch.fx.Node) -> bool: 245*523fa7a6SAndroid Build Coastguard Worker """ 246*523fa7a6SAndroid Build Coastguard Worker Check if the node is the getitem followed by executorch_call_delegate node. These getitems node 247*523fa7a6SAndroid Build Coastguard Worker are just for getting the result from delegate because the input/output to delegates are flattened 248*523fa7a6SAndroid Build Coastguard Worker """ 249*523fa7a6SAndroid Build Coastguard Worker return ( 250*523fa7a6SAndroid Build Coastguard Worker node.target == operator.getitem 251*523fa7a6SAndroid Build Coastguard Worker and len(node.args) == 2 252*523fa7a6SAndroid Build Coastguard Worker and node.args[0].target == executorch_call_delegate # pyre-ignore 253*523fa7a6SAndroid Build Coastguard Worker and isinstance(node.args[1], int) 254*523fa7a6SAndroid Build Coastguard Worker ) 255*523fa7a6SAndroid Build Coastguard Worker 256*523fa7a6SAndroid Build Coastguard Worker 257*523fa7a6SAndroid Build Coastguard Workerdef get_non_lowered_nodes(graph: torch.fx.Graph) -> List[torch.fx.Node]: 258*523fa7a6SAndroid Build Coastguard Worker """ 259*523fa7a6SAndroid Build Coastguard Worker Returns a list of non lowered nodes in the graph module. 260*523fa7a6SAndroid Build Coastguard Worker """ 261*523fa7a6SAndroid Build Coastguard Worker return [ 262*523fa7a6SAndroid Build Coastguard Worker node 263*523fa7a6SAndroid Build Coastguard Worker for node in graph.nodes 264*523fa7a6SAndroid Build Coastguard Worker if node.op == "call_function" 265*523fa7a6SAndroid Build Coastguard Worker and node.target != executorch_call_delegate 266*523fa7a6SAndroid Build Coastguard Worker and (not _get_item_from_executorch_call_delegate(node)) 267*523fa7a6SAndroid Build Coastguard Worker ] 268*523fa7a6SAndroid Build Coastguard Worker 269*523fa7a6SAndroid Build Coastguard Worker 270*523fa7a6SAndroid Build Coastguard Workerdef get_delegates(graph: torch.fx.Graph) -> List[torch.fx.Node]: 271*523fa7a6SAndroid Build Coastguard Worker """ 272*523fa7a6SAndroid Build Coastguard Worker Returns the list of delegates from the graph. 273*523fa7a6SAndroid Build Coastguard Worker """ 274*523fa7a6SAndroid Build Coastguard Worker return [ 275*523fa7a6SAndroid Build Coastguard Worker node 276*523fa7a6SAndroid Build Coastguard Worker for node in graph.nodes 277*523fa7a6SAndroid Build Coastguard Worker if node.op == "get_attr" and node.name.startswith("lowered_module_") 278*523fa7a6SAndroid Build Coastguard Worker ] 279*523fa7a6SAndroid Build Coastguard Worker 280*523fa7a6SAndroid Build Coastguard Worker 281*523fa7a6SAndroid Build Coastguard Workerdef print_delegated_graph(graph_module: torch.fx.GraphModule) -> None: 282*523fa7a6SAndroid Build Coastguard Worker """ 283*523fa7a6SAndroid Build Coastguard Worker Print the formatted graph string. 284*523fa7a6SAndroid Build Coastguard Worker """ 285*523fa7a6SAndroid Build Coastguard Worker print(format_delegated_graph(graph_module)) 286*523fa7a6SAndroid Build Coastguard Worker 287*523fa7a6SAndroid Build Coastguard Worker 288*523fa7a6SAndroid Build Coastguard Workerdef format_delegated_graph(graph_module: torch.fx.GraphModule) -> str: 289*523fa7a6SAndroid Build Coastguard Worker """ 290*523fa7a6SAndroid Build Coastguard Worker Return the formatted graph string of including lowered_module (both backend id and original graph) together with the graph module. Example output: 291*523fa7a6SAndroid Build Coastguard Worker graph(): 292*523fa7a6SAndroid Build Coastguard Worker %arg0_1 : [num_users=2] = placeholder[target=arg0_1] 293*523fa7a6SAndroid Build Coastguard Worker %arg1_1 : [num_users=2] = placeholder[target=arg1_1] 294*523fa7a6SAndroid Build Coastguard Worker %arg2_1 : [num_users=2] = placeholder[target=arg2_1] 295*523fa7a6SAndroid Build Coastguard Worker %lowered_module_0 : [num_users=1] = get_attr[target=lowered_module_0] 296*523fa7a6SAndroid Build Coastguard Worker backend_id: BackendWithCompilerDemo 297*523fa7a6SAndroid Build Coastguard Worker lowered graph(): 298*523fa7a6SAndroid Build Coastguard Worker %arg0_1 : [num_users=1] = placeholder[target=arg0_1] 299*523fa7a6SAndroid Build Coastguard Worker %arg1_1 : [num_users=1] = placeholder[target=arg1_1] 300*523fa7a6SAndroid Build Coastguard Worker %arg2_1 : [num_users=1] = placeholder[target=arg2_1] 301*523fa7a6SAndroid Build Coastguard Worker %aten_mm_default : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten.mm.default](args = (%arg0_1, %arg1_1), kwargs = {}) 302*523fa7a6SAndroid Build Coastguard Worker %aten_add_tensor : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten.add.Tensor](args = (%aten_mm_default, %arg2_1), kwargs = {}) 303*523fa7a6SAndroid Build Coastguard Worker return [aten_add_tensor] 304*523fa7a6SAndroid Build Coastguard Worker %executorch_call_delegate : [num_users=1] = call_function[target=torch.ops.higher_order.executorch_call_delegate](args = (%lowered_module_0, %arg0_1, %arg1_1, %arg2_1), kwargs = {}) 305*523fa7a6SAndroid Build Coastguard Worker %getitem : [num_users=1] = call_function[target=operator.getitem](args = (%executorch_call_delegate, 0), kwargs = {}) 306*523fa7a6SAndroid Build Coastguard Worker %aten_sub_tensor : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten.sub.Tensor](args = (%getitem, %arg0_1), kwargs = {}) 307*523fa7a6SAndroid Build Coastguard Worker %lowered_module_1 : [num_users=1] = get_attr[target=lowered_module_1] 308*523fa7a6SAndroid Build Coastguard Worker backend_id: BackendWithCompilerDemo 309*523fa7a6SAndroid Build Coastguard Worker lowered graph(): 310*523fa7a6SAndroid Build Coastguard Worker %aten_sub_tensor : [num_users=1] = placeholder[target=aten_sub_tensor] 311*523fa7a6SAndroid Build Coastguard Worker %arg1_1 : [num_users=1] = placeholder[target=arg1_1] 312*523fa7a6SAndroid Build Coastguard Worker %arg2_1 : [num_users=1] = placeholder[target=arg2_1] 313*523fa7a6SAndroid Build Coastguard Worker %aten_mm_default_1 : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten.mm.default](args = (%aten_sub_tensor, %arg1_1), kwargs = {}) 314*523fa7a6SAndroid Build Coastguard Worker %aten_add_tensor_1 : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten.add.Tensor](args = (%aten_mm_default_1, %arg2_1), kwargs = {}) 315*523fa7a6SAndroid Build Coastguard Worker return [aten_add_tensor_1] 316*523fa7a6SAndroid Build Coastguard Worker %executorch_call_delegate_1 : [num_users=1] = call_function[target=torch.ops.higher_order.executorch_call_delegate](args = (%lowered_module_1, %aten_sub_tensor, %arg1_1, %arg2_1), kwargs = {}) 317*523fa7a6SAndroid Build Coastguard Worker %getitem_1 : [num_users=1] = call_function[target=operator.getitem](args = (%executorch_call_delegate_1, 0), kwargs = {}) 318*523fa7a6SAndroid Build Coastguard Worker return [getitem_1] 319*523fa7a6SAndroid Build Coastguard Worker """ 320*523fa7a6SAndroid Build Coastguard Worker lowered_module_dict = { 321*523fa7a6SAndroid Build Coastguard Worker node.name: getattr(graph_module, node.name) 322*523fa7a6SAndroid Build Coastguard Worker for node in graph_module.graph.nodes 323*523fa7a6SAndroid Build Coastguard Worker if node.op == "get_attr" and node.name.startswith("lowered_module_") 324*523fa7a6SAndroid Build Coastguard Worker } 325*523fa7a6SAndroid Build Coastguard Worker indent = " " 326*523fa7a6SAndroid Build Coastguard Worker graph_format_str = "graph():\n" 327*523fa7a6SAndroid Build Coastguard Worker for node in graph_module.graph.nodes: 328*523fa7a6SAndroid Build Coastguard Worker graph_format_str += f"{indent}{node.format_node()}\n" 329*523fa7a6SAndroid Build Coastguard Worker if node.op == "get_attr" and node.name.startswith("lowered_module_"): 330*523fa7a6SAndroid Build Coastguard Worker lowered_module = lowered_module_dict[node.name] 331*523fa7a6SAndroid Build Coastguard Worker graph_format_str += f"{indent * 2}backend_id: {lowered_module.backend_id}\n" 332*523fa7a6SAndroid Build Coastguard Worker graph_format_str += f"{indent * 2}lowered graph():\n" 333*523fa7a6SAndroid Build Coastguard Worker for node_in_lowered_module in lowered_module.original_module.graph.nodes: 334*523fa7a6SAndroid Build Coastguard Worker graph_format_str += ( 335*523fa7a6SAndroid Build Coastguard Worker f"{indent * 3}{node_in_lowered_module.format_node()}\n" 336*523fa7a6SAndroid Build Coastguard Worker ) 337*523fa7a6SAndroid Build Coastguard Worker return graph_format_str 338*523fa7a6SAndroid Build Coastguard Worker 339*523fa7a6SAndroid Build Coastguard Worker 340*523fa7a6SAndroid Build Coastguard Workerdef tag_constant_data(edge_program: ExportedProgram) -> None: 341*523fa7a6SAndroid Build Coastguard Worker """ 342*523fa7a6SAndroid Build Coastguard Worker Util function for partitioners. This function tags the const/param/buffers nodes 343*523fa7a6SAndroid Build Coastguard Worker whose users all belong within the same partition. This should be called after tagging all other nodes. 344*523fa7a6SAndroid Build Coastguard Worker Any const/param/buffer which is used as input to a subgraph, will be tagged with the same tag as that 345*523fa7a6SAndroid Build Coastguard Worker subgraph. Throw error when const/param/buffers is used across different partitions. That is the 346*523fa7a6SAndroid Build Coastguard Worker underlying data will be owned by multiple delegates. 347*523fa7a6SAndroid Build Coastguard Worker """ 348*523fa7a6SAndroid Build Coastguard Worker mutated_buffer = set() 349*523fa7a6SAndroid Build Coastguard Worker for node in edge_program.graph.nodes: 350*523fa7a6SAndroid Build Coastguard Worker if node.op == "placeholder" and ( 351*523fa7a6SAndroid Build Coastguard Worker is_param(edge_program, node) 352*523fa7a6SAndroid Build Coastguard Worker or is_buffer(edge_program, node) 353*523fa7a6SAndroid Build Coastguard Worker or is_lifted_tensor_constant(edge_program, node) 354*523fa7a6SAndroid Build Coastguard Worker ): 355*523fa7a6SAndroid Build Coastguard Worker for node_user in node.users: 356*523fa7a6SAndroid Build Coastguard Worker if node_user.name in edge_program.graph_signature.buffers_to_mutate: 357*523fa7a6SAndroid Build Coastguard Worker logging.info( 358*523fa7a6SAndroid Build Coastguard Worker "The buffer node is a mutated buffer node, which is not constant." 359*523fa7a6SAndroid Build Coastguard Worker ) 360*523fa7a6SAndroid Build Coastguard Worker mutated_buffer.add(node) 361*523fa7a6SAndroid Build Coastguard Worker 362*523fa7a6SAndroid Build Coastguard Worker for node in edge_program.graph.nodes: 363*523fa7a6SAndroid Build Coastguard Worker # go through const/param/buffer nodes, if all users of const/param/buffer nodes are partitioned then partition 364*523fa7a6SAndroid Build Coastguard Worker if node.op == "placeholder" and ( 365*523fa7a6SAndroid Build Coastguard Worker is_param(edge_program, node) 366*523fa7a6SAndroid Build Coastguard Worker or is_buffer(edge_program, node) 367*523fa7a6SAndroid Build Coastguard Worker or is_lifted_tensor_constant(edge_program, node) 368*523fa7a6SAndroid Build Coastguard Worker ): 369*523fa7a6SAndroid Build Coastguard Worker if node not in mutated_buffer: 370*523fa7a6SAndroid Build Coastguard Worker user_tags = set() 371*523fa7a6SAndroid Build Coastguard Worker for user in node.users: 372*523fa7a6SAndroid Build Coastguard Worker user_tag = user.meta.get("delegation_tag", None) 373*523fa7a6SAndroid Build Coastguard Worker if user_tag is not None: 374*523fa7a6SAndroid Build Coastguard Worker user_tags.add(user_tag) 375*523fa7a6SAndroid Build Coastguard Worker if len(user_tags) > 1: 376*523fa7a6SAndroid Build Coastguard Worker logging.info( 377*523fa7a6SAndroid Build Coastguard Worker f"The data node is used across multiple partitions, including {user_tags}. " 378*523fa7a6SAndroid Build Coastguard Worker "If the data is too large and it's not preferred to copy, please tag the " 379*523fa7a6SAndroid Build Coastguard Worker "constant node like node.['no_copy'] = True and they won't be copied." 380*523fa7a6SAndroid Build Coastguard Worker ) 381*523fa7a6SAndroid Build Coastguard Worker # tag the data node with the same tag as the last user 382*523fa7a6SAndroid Build Coastguard Worker if len(user_tags) > 0: 383*523fa7a6SAndroid Build Coastguard Worker node.meta["delegation_tag"] = user_tags.pop() 384*523fa7a6SAndroid Build Coastguard Worker 385*523fa7a6SAndroid Build Coastguard Worker 386*523fa7a6SAndroid Build Coastguard Workerdef tag_mutated_buffer(edge_program: ExportedProgram) -> None: 387*523fa7a6SAndroid Build Coastguard Worker """ 388*523fa7a6SAndroid Build Coastguard Worker Util function for partitioners. This function tags the mutated buffer nodes 389*523fa7a6SAndroid Build Coastguard Worker whose users all belong within the same partition. This should be called after tagging all other nodes. 390*523fa7a6SAndroid Build Coastguard Worker Any buffer which is used as input to a subgraph, will be tagged with the same tag as that 391*523fa7a6SAndroid Build Coastguard Worker subgraph. Throw error when buffers is used across different partitions. That is the 392*523fa7a6SAndroid Build Coastguard Worker underlying data will be owned by multiple delegates. 393*523fa7a6SAndroid Build Coastguard Worker """ 394*523fa7a6SAndroid Build Coastguard Worker for node in edge_program.graph.nodes: 395*523fa7a6SAndroid Build Coastguard Worker # Determine whether this node is a mutated buffer 396*523fa7a6SAndroid Build Coastguard Worker is_mutated_buffer_node = False 397*523fa7a6SAndroid Build Coastguard Worker if node.op == "placeholder" and is_buffer(edge_program, node): 398*523fa7a6SAndroid Build Coastguard Worker for node_user in node.users: 399*523fa7a6SAndroid Build Coastguard Worker if node_user.name in edge_program.graph_signature.buffers_to_mutate: 400*523fa7a6SAndroid Build Coastguard Worker is_mutated_buffer_node = True 401*523fa7a6SAndroid Build Coastguard Worker break 402*523fa7a6SAndroid Build Coastguard Worker # This node is mutated buffer, tag it 403*523fa7a6SAndroid Build Coastguard Worker if is_mutated_buffer_node: 404*523fa7a6SAndroid Build Coastguard Worker user_tags = set() 405*523fa7a6SAndroid Build Coastguard Worker for user in node.users: 406*523fa7a6SAndroid Build Coastguard Worker user_tag = user.meta.get("delegation_tag", None) 407*523fa7a6SAndroid Build Coastguard Worker if user_tag is not None: 408*523fa7a6SAndroid Build Coastguard Worker user_tags.add(user_tag) 409*523fa7a6SAndroid Build Coastguard Worker if len(user_tags) > 1: 410*523fa7a6SAndroid Build Coastguard Worker logging.info( 411*523fa7a6SAndroid Build Coastguard Worker f"The data node is used across multiple partitions, including {user_tags}. " 412*523fa7a6SAndroid Build Coastguard Worker "If the data is too large and it's not preferred to copy, please tag the " 413*523fa7a6SAndroid Build Coastguard Worker "constant node like node.['no_copy'] = True and they won't be copied." 414*523fa7a6SAndroid Build Coastguard Worker ) 415*523fa7a6SAndroid Build Coastguard Worker # tag the data node with the same tag as the last user 416*523fa7a6SAndroid Build Coastguard Worker if len(user_tags) > 0: 417*523fa7a6SAndroid Build Coastguard Worker node.meta["delegation_tag"] = user_tags.pop() 418*523fa7a6SAndroid Build Coastguard Worker 419*523fa7a6SAndroid Build Coastguard Worker 420*523fa7a6SAndroid Build Coastguard Worker# TODO - style: use templated types 421*523fa7a6SAndroid Build Coastguard Workerclass DelegateMappingBuilder: 422*523fa7a6SAndroid Build Coastguard Worker """ 423*523fa7a6SAndroid Build Coastguard Worker Profiling helper class for building Delegate Mappings. 424*523fa7a6SAndroid Build Coastguard Worker Delegate Mappings are mappings from delegate debug identifiers to node 425*523fa7a6SAndroid Build Coastguard Worker debug handles. Specifically this is used to log within backend delegates 426*523fa7a6SAndroid Build Coastguard Worker 427*523fa7a6SAndroid Build Coastguard Worker Args: 428*523fa7a6SAndroid Build Coastguard Worker generated_identifiers (bool, optional): Whether identifier keys are 429*523fa7a6SAndroid Build Coastguard Worker generated automatically. Defaults to False. 430*523fa7a6SAndroid Build Coastguard Worker """ 431*523fa7a6SAndroid Build Coastguard Worker 432*523fa7a6SAndroid Build Coastguard Worker def __init__(self, generated_identifiers: bool = False): 433*523fa7a6SAndroid Build Coastguard Worker self._generated_identifiers = generated_identifiers 434*523fa7a6SAndroid Build Coastguard Worker 435*523fa7a6SAndroid Build Coastguard Worker # Note that the internal struct has a Set value, while the getter 436*523fa7a6SAndroid Build Coastguard Worker # function returns the values as a tuple 437*523fa7a6SAndroid Build Coastguard Worker self._debug_handle_map: Union[Dict[int, Set[int]], Dict[str, Set[int]]] = ( 438*523fa7a6SAndroid Build Coastguard Worker defaultdict(set) 439*523fa7a6SAndroid Build Coastguard Worker ) 440*523fa7a6SAndroid Build Coastguard Worker self._next_index: int = 0 441*523fa7a6SAndroid Build Coastguard Worker 442*523fa7a6SAndroid Build Coastguard Worker def get_delegate_mapping( 443*523fa7a6SAndroid Build Coastguard Worker self, 444*523fa7a6SAndroid Build Coastguard Worker ) -> Union[Dict[int, Tuple[int]], Dict[str, Tuple[int]]]: 445*523fa7a6SAndroid Build Coastguard Worker """ 446*523fa7a6SAndroid Build Coastguard Worker Returns: 447*523fa7a6SAndroid Build Coastguard Worker Union[Dict[int, Tuple[int]], Dict[str, Tuple[int]]]: 448*523fa7a6SAndroid Build Coastguard Worker A map of delegate debug identifier to a list of debug handles 449*523fa7a6SAndroid Build Coastguard Worker The keys (identifier) are either integers or strings 450*523fa7a6SAndroid Build Coastguard Worker The values are a sorted tuple of integer debug handles 451*523fa7a6SAndroid Build Coastguard Worker """ 452*523fa7a6SAndroid Build Coastguard Worker # pyre-ignore Warning between Union[Dict[K, V], Dict[K2, V]] vs Dict[Union[K, K2], V] 453*523fa7a6SAndroid Build Coastguard Worker return {k: tuple(sorted(v)) for k, v in self._debug_handle_map.items()} 454*523fa7a6SAndroid Build Coastguard Worker 455*523fa7a6SAndroid Build Coastguard Worker def insert_delegate_mapping_entry( 456*523fa7a6SAndroid Build Coastguard Worker self, 457*523fa7a6SAndroid Build Coastguard Worker nodes: Optional[Union[Node, List[Node]]] = None, 458*523fa7a6SAndroid Build Coastguard Worker handles: Optional[Union[int, List[Optional[int]]]] = None, 459*523fa7a6SAndroid Build Coastguard Worker identifier: Optional[Union[int, str]] = None, 460*523fa7a6SAndroid Build Coastguard Worker ) -> Union[int, str]: 461*523fa7a6SAndroid Build Coastguard Worker """ 462*523fa7a6SAndroid Build Coastguard Worker Add a new delegate mapping entry 463*523fa7a6SAndroid Build Coastguard Worker 464*523fa7a6SAndroid Build Coastguard Worker If self._generated_identifiers = False: 465*523fa7a6SAndroid Build Coastguard Worker - A new identifier must be provided, else an exception is thrown 466*523fa7a6SAndroid Build Coastguard Worker 467*523fa7a6SAndroid Build Coastguard Worker If self._generated_identifiers = True: 468*523fa7a6SAndroid Build Coastguard Worker - New identifiers are generated incrementally, 0 indexed 469*523fa7a6SAndroid Build Coastguard Worker - Identifiers cannot be manually provided, else an exception is thrown 470*523fa7a6SAndroid Build Coastguard Worker 471*523fa7a6SAndroid Build Coastguard Worker Args: 472*523fa7a6SAndroid Build Coastguard Worker nodes (Union[Node, List[Node]]): A (list of) Node(s) 473*523fa7a6SAndroid Build Coastguard Worker handles (Union[int, List[Optional[int]]]): A (list of) debug handle(s) 474*523fa7a6SAndroid Build Coastguard Worker identifier (Optional[Union[int, str]]): 475*523fa7a6SAndroid Build Coastguard Worker Debug identifier corresponding to the Node(s) 476*523fa7a6SAndroid Build Coastguard Worker 477*523fa7a6SAndroid Build Coastguard Worker Note: Exactly one of nodes and handles must be provided 478*523fa7a6SAndroid Build Coastguard Worker Note: If a debug handle is missing or None, it is skipped 479*523fa7a6SAndroid Build Coastguard Worker 480*523fa7a6SAndroid Build Coastguard Worker Returns: 481*523fa7a6SAndroid Build Coastguard Worker Union[int, str]: 482*523fa7a6SAndroid Build Coastguard Worker Delegate debug identifier inserted 483*523fa7a6SAndroid Build Coastguard Worker """ 484*523fa7a6SAndroid Build Coastguard Worker 485*523fa7a6SAndroid Build Coastguard Worker # Check for manual addition of identifier (with generated identifiers enabled) 486*523fa7a6SAndroid Build Coastguard Worker if self._generated_identifiers and identifier is not None: 487*523fa7a6SAndroid Build Coastguard Worker raise Exception( 488*523fa7a6SAndroid Build Coastguard Worker f"Builders using generated identifiers can't manually add identifiers: {identifier}. Failed to add or update entry" 489*523fa7a6SAndroid Build Coastguard Worker ) 490*523fa7a6SAndroid Build Coastguard Worker 491*523fa7a6SAndroid Build Coastguard Worker if identifier is not None and identifier in self._debug_handle_map: 492*523fa7a6SAndroid Build Coastguard Worker raise Exception( 493*523fa7a6SAndroid Build Coastguard Worker "This delegate debug identifier was already inserted. Duplicate delegate debug identifiers are not allowed." 494*523fa7a6SAndroid Build Coastguard Worker ) 495*523fa7a6SAndroid Build Coastguard Worker 496*523fa7a6SAndroid Build Coastguard Worker # Check for exactly one of nodes and handles being populated 497*523fa7a6SAndroid Build Coastguard Worker if not ((nodes is not None) ^ (handles is not None)): 498*523fa7a6SAndroid Build Coastguard Worker raise Exception( 499*523fa7a6SAndroid Build Coastguard Worker "Only one of nodes or handles must be provided. Either both were provided or neither were provided. Failed to add or update entry." 500*523fa7a6SAndroid Build Coastguard Worker ) 501*523fa7a6SAndroid Build Coastguard Worker 502*523fa7a6SAndroid Build Coastguard Worker # Resolve Identifier 503*523fa7a6SAndroid Build Coastguard Worker if identifier is None: 504*523fa7a6SAndroid Build Coastguard Worker if self._generated_identifiers: 505*523fa7a6SAndroid Build Coastguard Worker identifier = self._next_index 506*523fa7a6SAndroid Build Coastguard Worker self._next_index += 1 507*523fa7a6SAndroid Build Coastguard Worker else: 508*523fa7a6SAndroid Build Coastguard Worker raise Exception( 509*523fa7a6SAndroid Build Coastguard Worker "No identifier provided. Failed to add or update entry." 510*523fa7a6SAndroid Build Coastguard Worker ) 511*523fa7a6SAndroid Build Coastguard Worker 512*523fa7a6SAndroid Build Coastguard Worker # Collect debug handles 513*523fa7a6SAndroid Build Coastguard Worker if nodes is not None: 514*523fa7a6SAndroid Build Coastguard Worker new_debug_handles = { 515*523fa7a6SAndroid Build Coastguard Worker node.meta.get("debug_handle") 516*523fa7a6SAndroid Build Coastguard Worker for node in (nodes if isinstance(nodes, List) else [nodes]) 517*523fa7a6SAndroid Build Coastguard Worker } 518*523fa7a6SAndroid Build Coastguard Worker else: 519*523fa7a6SAndroid Build Coastguard Worker new_debug_handles = ( 520*523fa7a6SAndroid Build Coastguard Worker handles if isinstance(handles, (tuple, List)) else [handles] 521*523fa7a6SAndroid Build Coastguard Worker ) 522*523fa7a6SAndroid Build Coastguard Worker 523*523fa7a6SAndroid Build Coastguard Worker # Filter for empty debug handles 524*523fa7a6SAndroid Build Coastguard Worker filtered_debug_handles = { 525*523fa7a6SAndroid Build Coastguard Worker handle for handle in new_debug_handles if handle is not None 526*523fa7a6SAndroid Build Coastguard Worker } 527*523fa7a6SAndroid Build Coastguard Worker if len(filtered_debug_handles) == 0: 528*523fa7a6SAndroid Build Coastguard Worker raise Exception("No valid debug handles found. Failed to add entry.") 529*523fa7a6SAndroid Build Coastguard Worker 530*523fa7a6SAndroid Build Coastguard Worker # pyre-ignore Warning from Union[int, st] keys 531*523fa7a6SAndroid Build Coastguard Worker self._debug_handle_map[identifier] = filtered_debug_handles 532*523fa7a6SAndroid Build Coastguard Worker return identifier 533*523fa7a6SAndroid Build Coastguard Worker 534*523fa7a6SAndroid Build Coastguard Worker 535*523fa7a6SAndroid Build Coastguard Workerclass WhyNoPartition: 536*523fa7a6SAndroid Build Coastguard Worker """ 537*523fa7a6SAndroid Build Coastguard Worker Simple helper class for partitioners to log why a node was not lowered. 538*523fa7a6SAndroid Build Coastguard Worker 539*523fa7a6SAndroid Build Coastguard Worker Example usage: 540*523fa7a6SAndroid Build Coastguard Worker 541*523fa7a6SAndroid Build Coastguard Worker # In your backend partitioner file(s) 542*523fa7a6SAndroid Build Coastguard Worker why = WhyNoPartition(logger=your_backend_logger) 543*523fa7a6SAndroid Build Coastguard Worker 544*523fa7a6SAndroid Build Coastguard Worker # hypothetical function that checks if a node can be lowered 545*523fa7a6SAndroid Build Coastguard Worker if not can_be_lowered(node): 546*523fa7a6SAndroid Build Coastguard Worker why(node, "This node was not lowered because ...") 547*523fa7a6SAndroid Build Coastguard Worker """ 548*523fa7a6SAndroid Build Coastguard Worker 549*523fa7a6SAndroid Build Coastguard Worker def __init__(self, logger: logging.Logger): 550*523fa7a6SAndroid Build Coastguard Worker self.logger = logger 551*523fa7a6SAndroid Build Coastguard Worker self.node: Optional[torch.fx.Node] = None 552*523fa7a6SAndroid Build Coastguard Worker self.reason: str = "" 553*523fa7a6SAndroid Build Coastguard Worker 554*523fa7a6SAndroid Build Coastguard Worker def __call__(self, node: torch.fx.Node, reason: str) -> None: 555*523fa7a6SAndroid Build Coastguard Worker self.node = node 556*523fa7a6SAndroid Build Coastguard Worker self.reason = reason 557*523fa7a6SAndroid Build Coastguard Worker self.logger.debug(self) 558*523fa7a6SAndroid Build Coastguard Worker 559*523fa7a6SAndroid Build Coastguard Worker def __str__(self) -> str: 560*523fa7a6SAndroid Build Coastguard Worker return f"WhyNoPartition: Node {self.node} was not partitioned because {self.reason}." 561