# Copyright (c) Qualcomm Innovation Center, Inc.
# All rights reserved
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import torch
from executorch.exir.pass_base import ExportPass, PassResult
from torch.fx.experimental.proxy_tensor import make_fx


class DecomposeEinsum(ExportPass):
    """
    Decompose einsum for quantization annotation to work properly.

    Each ``aten.einsum.default`` node is traced with ``make_fx`` and the
    nodes of the traced graph are copied into the original graph in its place.
    """

    def __init__(self) -> None:
        super().__init__()

    def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
        """
        Inline the traced decomposition of every einsum node in ``graph_module``.

        Args:
            graph_module: module whose graph is rewritten in place.

        Returns:
            A ``PassResult`` wrapping the same ``graph_module``; the modified
            flag is always ``True``.
        """
        graph = graph_module.graph
        einsum_op = torch.ops.aten.einsum.default

        for node in graph.nodes:
            if node.target != einsum_op:
                continue

            # Unlike most ops, einsum keeps its equation string in args[0]
            # and the list of input nodes in args[1].
            tracer = make_fx(node.target, tracing_mode="fake")
            traced_module = tracer(
                node.args[0],
                [operand.meta["val"] for operand in node.args[1]],
            )

            with graph.inserting_before(node):
                # remap translates values of the traced graph into values of
                # the original graph so that references are correctly updated
                # while copying. It is seeded by placeholder name
                # ("arg1_1", "arg1_2", ...) and re-keyed by node further down.
                remap = {
                    f"arg1_{position}": operand
                    for position, operand in enumerate(node.args[1], start=1)
                }

                for traced_node in traced_module.graph.nodes:
                    # Placeholder of the args[0] equation string; it is not
                    # needed anymore once einsum has been decomposed.
                    if "arg0" in traced_node.name:
                        continue

                    if traced_node.op == "placeholder":
                        # Existing inputs are reused rather than copied: swap
                        # the string key for the traced placeholder node.
                        remap[traced_node] = remap.pop(traced_node.name)
                    elif traced_node.op == "output":
                        # The original graph already has its own output, so
                        # only redirect the einsum users to the copied result.
                        for user in list(node.users):
                            user.replace_input_with(
                                node,
                                remap[traced_node.args[0][0]],
                            )
                    else:
                        remap[traced_node] = graph.node_copy(
                            traced_node,
                            arg_transform=remap.__getitem__,
                        )

                graph.erase_node(node)

        graph.eliminate_dead_code()
        graph_module.recompile()
        return PassResult(graph_module, True)