1# Copyright (c) Meta Platforms, Inc. and affiliates. 2# All rights reserved. 3# 4# This source code is licensed under the BSD-style license found in the 5# LICENSE file in the root directory of this source tree. 6 7import logging 8from typing import List 9 10import torch 11 12from executorch.backends.transforms import get_shape 13from executorch.backends.transforms.addmm_mm_to_linear import ( 14 apply_addmm_mm_to_linear_transform, 15) 16from executorch.exir.dialects._ops import ops as exir_ops 17from executorch.exir.pass_base import ExportPass 18 19from torch.fx.passes.infra.pass_base import PassResult 20from torch.fx.passes.utils.source_matcher_utils import ( 21 get_source_partitions, 22 SourcePartition, 23) 24 25logger = logging.getLogger(__name__) 26logger.setLevel(logging.WARNING) 27 28 29class ConvertToLinearPass(ExportPass): 30 linear_modules = [ 31 torch.nn.Linear, 32 torch.nn.functional.linear, 33 ] 34 35 targets = [ 36 exir_ops.edge.aten.mm.default, 37 exir_ops.edge.aten.addmm.default, 38 exir_ops.edge.aten.bmm.default, 39 ] 40 41 @staticmethod 42 def find( 43 node: torch.fx.Node, 44 args: List[torch.fx.Node], 45 kind: str = "args", 46 index: int = 0, 47 ): 48 # This is a hack to support lifted graphs. 49 # TODO(T171263351) - fix source partitioning for lifted graphs 50 if not node or node in args or node.op == "placeholder": 51 return node 52 if kind == "args": 53 other = node.args[index] 54 elif kind == "users": 55 other = list(node.users.keys())[index] 56 else: 57 raise AssertionError(f"Unexpected kind: {kind}") 58 return ConvertToLinearPass.find(other, args, kind) # pyre-ignore[6] 59 60 @staticmethod 61 def get_arg(node: torch.fx.Node, arg: str): 62 if node.target == exir_ops.edge.aten.addmm.default: 63 map_ = { 64 "bias": 0, 65 "input": 1, 66 "weight": 2, 67 } 68 return node.args[map_[arg]] 69 else: 70 map_ = {"input": 0, "weight": 1} 71 return None if arg == "bias" else node.args[map_[arg]] 72 73 def find_bias_for_mm(self, src_partition: SourcePartition, mm_node: torch.fx.Node): 74 """ 75 For linear decomposed with mm + add, find bias in src partition 76 """ 77 78 mm_users = list(mm_node.users.keys()) 79 if len(mm_users) != 1: 80 return None 81 82 add_node = mm_users[0] 83 if add_node.target != exir_ops.edge.aten.add.Tensor: 84 return None 85 86 for arg in add_node.all_input_nodes: 87 if arg != mm_node and arg in src_partition.input_nodes: 88 return arg 89 90 return None 91 92 def create_linear( 93 self, 94 graph_module: torch.fx.GraphModule, 95 node: torch.fx.Node, 96 src_partition: SourcePartition, 97 ): 98 logger.debug(f"Source Partition: {src_partition}") 99 linear_input = self.find( 100 self.get_arg(node, "input"), 101 src_partition.input_nodes, 102 ) 103 logger.debug(f"Found input: {linear_input} from node {node}") 104 105 linear_weight = self.find( 106 self.get_arg(node, "weight"), 107 src_partition.input_nodes 108 + src_partition.params, # non quant weight can be in params 109 ) 110 logger.debug(f"Found weight: {linear_weight} from node {node}") 111 112 linear_bias = self.find( 113 self.get_arg(node, "bias"), 114 src_partition.input_nodes + src_partition.params, # bias can be in params 115 ) 116 if linear_bias is None and node.target == exir_ops.edge.aten.mm.default: 117 linear_bias = self.find_bias_for_mm(src_partition, node) 118 119 logger.debug(f"Found bias(?): {linear_bias} from node {node}") 120 121 # Ignore dynamic shape nodes 122 outputs = [ 123 node 124 for node in src_partition.output_nodes 125 if node.target != torch.ops.aten.sym_size.int and node.op != "placeholder" 126 ] 127 assert ( 128 len(outputs) == 1 129 ), f"Unexpected number of outputs for a torch.nn.Linear module, expecting 1 but got {outputs}" 130 output = outputs[0] 131 132 with graph_module.graph.inserting_before(output): 133 args = (linear_input, linear_weight) 134 if linear_bias is not None: 135 args += (linear_bias,) 136 linear_node = graph_module.graph.create_node( 137 "call_function", 138 exir_ops.edge.aten.linear.default, # HACK not edge_op/CATen 139 args, 140 ) 141 # TODO - calculate output even when dynamic_shape=True 142 linear_node.meta["val"] = torch.zeros(get_shape(output)) 143 logger.debug( 144 f"Replacing {output}{get_shape(output)} node with {linear_node}{get_shape(linear_node)}" 145 ) 146 output.replace_all_uses_with(linear_node) 147 graph_module.graph.eliminate_dead_code() 148 149 # override 150 def call(self, graph_module: torch.fx.GraphModule): 151 logger.debug("ConvertToLinear Begin: ") 152 logger.debug(graph_module.print_readable(print_output=False)) 153 154 processed_partitions = 0 155 while True: 156 src_partition_dict = get_source_partitions( 157 graph_module.graph, self.linear_modules 158 ) 159 160 src_node_dict = { 161 node: src_partition 162 for src_partitions in src_partition_dict.values() 163 for src_partition in src_partitions 164 for node in src_partition.nodes 165 if node.target in self.targets 166 } 167 168 # No more [add]mm target in source partitions 169 if len(src_node_dict) == 0: 170 if processed_partitions == 0: 171 logger.debug( 172 "Did not find any [add]mm target in source partitions, skipping the pass." 173 ) 174 else: 175 logger.debug( 176 f"Converted {processed_partitions} [add]mm target(s) into Linear." 177 ) 178 break 179 180 logger.debug("Converting [add]mm into Linear") 181 for node in src_node_dict.keys(): 182 self.create_linear(graph_module, node, src_node_dict[node]) 183 processed_partitions += 1 184 # Only convert the first [add]mm target 185 break 186 187 # fall back to linear transform 188 graph_module.graph = apply_addmm_mm_to_linear_transform(graph_module.graph) 189 190 graph_module.recompile() 191 192 # Propagate metadata and retrace module 193 graph_module = super().call(graph_module).graph_module 194 195 logger.debug("ConvertToLinear End: ") 196 logger.debug(graph_module.print_readable(print_output=False)) 197 198 return PassResult(graph_module, True) 199